49.3_数据的连接

49.3 数据的连接

训练神经网络时会从数据集中取出一部分数据作为小批量数据。下面是使用Spiral类取出小批量数据的代码。

train_set =dezero.datasets.Spiral()   
batch_index  $= [0,1,2]$  #取出第0个~第2个数据   
batch  $=$  [train_set[i]for i in batch_index] #batch  $= [(data_{-}0$  ,label_0),(data_1,label_1),(data_2,label_2)]

上面的代码首先通过索引操作提取多个数据(小批量数据)。代码中的batch是由多个数据组成的列表。要作为DeZero的神经网络的输入,我们还需将这些数据转换为ndarray实例。下面是用于转换的代码。

$\mathbf{x} =$  np.array([example[0] for example in batch])   
t  $=$  np.array([example[1] for example in batch])   
print(x.shape)   
print(t.shape)

运行结果

(3,2) (3,)

上面的代码从batch的每个元素中提取数据(或标签),并将它们转换(连接)为一个ndarray实例。这样,这些数据就可以作为神经网络的输入使用了。