51.1_MNIST数据集
51.1 MNIST数据集
DeZero在dezero/datasets.py中提供了MNIST类,这个MNIST类继承了Dataset类。它的使用示例如下所示。
importdezero
train_set $\equiv$ dezero.datasets.MNIST(train=True,transform $=$ None) test_set $\equiv$ dezero.datasets.MNIST(train $\equiv$ False,transform $\equiv$ None) print(len(train_set))
print(len(test_set))运行结果
60000
10000上面的代码分别获取了用于训练的数据和用于测试的数据。代码中通过设置transform=None来(显式地)指定不对数据进行预处理。之后查看了用于训练的数据(train_set)和用于测试的数据(test_set)的数据长度。结果是train_set为60000,test_set为10000。也就是说,有60000个训练数据和10000个测试数据。接下来运行以下代码。
x, t = train_set[0]
print(type(x), x.shape)
print(t)运行结果
<class 'numpy.ndarray'> (1, 28, 28)
5上面的代码从train_set中抽取了第0个样本数据。取出的MNIST数据集的数据形式为(data, label),即包含data(图像)和标签的元组。另外,MNIST的输入数据的形状是(1,28,28),这意味着图像数据是1个通道(灰度)、 像素的数据。标签是作为正确答案的数字的索引 。下面
尝试对数据执行可视化操作。
import matplotlib.pyplot as plt数据示例
x,t $=$ train_set[0] #取出第0个(data,label)plt.imshow(x.reshape(28,28),cmap $\equiv$ 'gray')plt.axis('off')plt.show()print('label:',t)运行结果
5
运行上面的代码,我们会看到图51-2所示的图像。

图51-2 MINST的图像示例
下面使用神经网络来训练这个手写图像的数据。在训练之前,我们需要对输入数据进行预处理。代码如下所示。
def f(x):
x = x Flatten()
x = x.astype(np.float32)
x /= 255.0
return x
train_set =dezero.datasets.MNIST(train=True,transform=f)
test_set =dezero.datasets.MNIST(train=False,transform=f)首先将输入数据排成一列。这样,输入数据的形状就从(1, 28, 28)转换为(784,)。然后将数据的类型转换为np.float32(32位浮点数)。最后除以255.0,将输入数据转换为0.0和1.0之间的数值。这些预处理在MNIST类中是默认进行的。因此,编写dezerodatasets.MNIST(train=True)也会执行上述预处理(dezero/datasets.py包含使用dezero/transfers.py中的类进行预处理的代码)。