50.2_使用 DataLoader

50.2 使用 DataLoader

现在我们来用一下 DataLoader 类。使用这个 DataLoader 类可以轻松取出小批量数据。下面是在神经网络训练的场景使用 DataLoader 的示例,代码如下所示。

fromdezero.datasetsimport Spiral   
fromdezero import DataLoader   
batch_size  $= 10$    
max_epoch  $= 1$    
train_set  $=$  Spiral(train=True)   
test_set  $=$  Spiral(train=False)   
trainloader  $=$  DataLoader(train_set,batch_size)   
testloader  $=$  DataLoader(test_set,batch_size,shuffle  $\equiv$  False)   
for epoch in range(max_epoch): forx,tin trainloader: print(x.shape,t.shape)#x、t是训练数据 break #在每轮训练结束时取出测试数据 forx,tin testloader: print(x.shape,t.shape)#x、t是测试数据 break

运行结果

(10,2) (10,)  
(10,2) (10,)

上面的代码创建了两个 DataLoader,分别用于训练和测试。由于用于训练的 DataLoader 在每轮训练时要对数据进行重排,所以被设置为 shuffle=True(默认);用于测试的 DataLoader 只用于精度评估,所以被设置为 shuffle=False。设置好之后,DataLoader 就会进行小批量数据的取出和重排工作。

接下来,使用 DataLoader 类训练螺旋数据集。不过在此之前,我们再添加一个工具函数。