49.3_数据的连接
49.3 数据的连接
训练神经网络时会从数据集中取出一部分数据作为小批量数据。下面是使用Spiral类取出小批量数据的代码。
train_set =dezero.datasets.Spiral()
batch_index $= [0,1,2]$ #取出第0个~第2个数据
batch $=$ [train_set[i]for i in batch_index] #batch $= [(data_{-}0$ ,label_0),(data_1,label_1),(data_2,label_2)]上面的代码首先通过索引操作提取多个数据(小批量数据)。代码中的batch是由多个数据组成的列表。要作为DeZero的神经网络的输入,我们还需将这些数据转换为ndarray实例。下面是用于转换的代码。
$\mathbf{x} =$ np.array([example[0] for example in batch])
t $=$ np.array([example[1] for example in batch])
print(x.shape)
print(t.shape)运行结果
(3,2) (3,)上面的代码从batch的每个元素中提取数据(或标签),并将它们转换(连接)为一个ndarray实例。这样,这些数据就可以作为神经网络的输入使用了。