49.5_数据集的预处理
49.5 数据集的预处理
在向机器学习的模型输入数据之前,通常要对数据进行一定的处理,比如从数据中减去某个值,或者改变数据的形状。常见的处理还有数据增强,即通过旋转或翻转图像等方式来人为地增加数据。为了支持这些预处理(以及数据增强),我们向Dataset类中添加以下实现预处理功能的代码。
dezero/datasets.py
class Dataset: def__init__(self,train $\equiv$ True,transform $\equiv$ None,target_transform $\equiv$ None): self.train $=$ train self.transform $=$ transform self.target_transform $=$ target_transform if self.transform is None: self.transform $=$ lambda x:x if self.target_transform is None: self.target_transform $=$ lambda x:x self.data $=$ None self.label $=$ None self.prepare()
def__getitem_(self,index): assert np.isscalar(index) if self.label is None: return self.transform(self.data[index]),None else: return self.transform(self.data[index]),\,\, $\mathrm{sim}$ .target_transform(self.label[index])
def__len__(self): return len(self.data)
defprepare(self): pass上面的代码在初始化阶段接收transform和target_transform等新的参数。这些参数是可调用的对象(如Python的函数等)。transform对单个输入数据进行转换处理,target_transform对单个标签进行转换处理。如果传来的参数是None,预处理则会被设置为lambda x:x,lambda表达式可以直接返回参数(即不做预处理)。有了设置预处理的功能后,我们可以写出以下代码。
def f(x):
y = x / 2.0
return y
train_set =dezero.datasets.Spiral(transform $\equiv$ f)上面的示例代码针对输入数据执行了将其缩放为一半的预处理。用户可以像这样对数据集添加任何预处理。DeZero在dezero/transforms.py中内置了常用的预处理转换,比如数据正则化处理及与图像数据(PIL.Image实例)相关的转换处理。下面是一个实际的使用示例。
fromdezero import transforms
f $=$ transforms.Normalize(mean $\coloneqq$ 0.0,std $\coloneqq$ 2.0)
train_set $=$ dezero.datasets.Spiral(transforms $\equiv$ f)如果输入是x,那么上面代码中的transform.Normalize(mean=0.0, std=2.0)就会将x转换为(x - mean) / std。如果想连续进行多个转换,可以编写以下代码。
f = transformsCompose([transforms.Normalize(mean=0.0, std=2.0), transforms"AsType(np.float64)])transformCompose类按顺序从头开始处理列表中的转换。上面的代码首先进行数据的正则化处理,然后将数据类型转换为np.float64。dezero/transform.py中内置了许多有用的转换处理,代码都很简单,这里不再赘述,感兴趣的读者可以查看。
步骤50
用于取出小批量数据的 DataLoader
上一个步骤创建了Dataset类,并建立了通过指定的接口访问数据集的机制。本步骤将实现Dataset类中创建小批量数据的DataLoader(数据加载器)类,让这个类来完成创建小批量数据和数据集重排等工作。由此,用户编写的训练代码会变得更加简洁。这里笔者先介绍迭代器(iterator),然后介绍如何实现DataLoader类。