49.5_数据集的预处理

49.5 数据集的预处理

在向机器学习的模型输入数据之前,通常要对数据进行一定的处理,比如从数据中减去某个值,或者改变数据的形状。常见的处理还有数据增强,即通过旋转或翻转图像等方式来人为地增加数据。为了支持这些预处理(以及数据增强),我们向Dataset类中添加以下实现预处理功能的代码。

dezero/datasets.py

class Dataset: def__init__(self,train  $\equiv$  True,transform  $\equiv$  None,target_transform  $\equiv$  None): self.train  $=$  train self.transform  $=$  transform self.target_transform  $=$  target_transform if self.transform is None: self.transform  $=$  lambda x:x if self.target_transform is None: self.target_transform  $=$  lambda x:x self.data  $=$  None self.label  $=$  None self.prepare()   
def__getitem_(self,index): assert np.isscalar(index) if self.label is None: return self.transform(self.data[index]),None else: return self.transform(self.data[index]),\,\, $\mathrm{sim}$  .target_transform(self.label[index])   
def__len__(self): return len(self.data)   
defprepare(self): pass

上面的代码在初始化阶段接收transform和target_transform等新的参数。这些参数是可调用的对象(如Python的函数等)。transform对单个输入数据进行转换处理,target_transform对单个标签进行转换处理。如果传来的参数是None,预处理则会被设置为lambda x:x,lambda表达式可以直接返回参数(即不做预处理)。有了设置预处理的功能后,我们可以写出以下代码。

def f(x):
    y = x / 2.0
    return y
train_set =dezero.datasets.Spiral(transform  $\equiv$  f)

上面的示例代码针对输入数据执行了将其缩放为一半的预处理。用户可以像这样对数据集添加任何预处理。DeZero在dezero/transforms.py中内置了常用的预处理转换,比如数据正则化处理及与图像数据(PIL.Image实例)相关的转换处理。下面是一个实际的使用示例。

fromdezero import transforms   
f  $=$  transforms.Normalize(mean  $\coloneqq$  0.0,std  $\coloneqq$  2.0)   
train_set  $=$ dezero.datasets.Spiral(transforms  $\equiv$  f)

如果输入是x,那么上面代码中的transform.Normalize(mean=0.0, std=2.0)就会将x转换为(x - mean) / std。如果想连续进行多个转换,可以编写以下代码。

f = transformsCompose([transforms.Normalize(mean=0.0, std=2.0), transforms"AsType(np.float64)])

transformCompose类按顺序从头开始处理列表中的转换。上面的代码首先进行数据的正则化处理,然后将数据类型转换为np.float64。dezero/transform.py中内置了许多有用的转换处理,代码都很简单,这里不再赘述,感兴趣的读者可以查看。

步骤50

用于取出小批量数据的 DataLoader

上一个步骤创建了Dataset类,并建立了通过指定的接口访问数据集的机制。本步骤将实现Dataset类中创建小批量数据的DataLoader(数据加载器)类,让这个类来完成创建小批量数据和数据集重排等工作。由此,用户编写的训练代码会变得更加简洁。这里笔者先介绍迭代器(iterator),然后介绍如何实现DataLoader类。