49.1_Dataset类的实现

49.1 Dataset类的实现

Dataset类是作为基类实现的。我们让用户实际使用的数据集类继承Dataset类。Dataset类的代码如下所示。

dezero/datasets.py

import numpy as np

class Dataset:

def__init__(self,train  $\equiv$  True): self.train  $=$  train self.data  $=$  None self.label  $=$  None self.prepare()   
def__getitem__(self,index): assert np.isscalar(index)#只支持index是整数(标量)的情况 if self.label is None: return self.data[index],None else: return self.data[index],self.label[index]   
deflen__(self): return len(self.data)   
defprepare(self): pass

首先,初始化方法接收train参数。它是用于区分“训练”和“测试”的标志位。另外,Dataset类由保存输入数据的实例变量data和保存标签的实例变量label构成。之后调用的prepare方法用于准备数据,用户需要在继承了Dataset的类中实现这个方法。

Dataset类中最重要的是__getitem__和__len__这两个方法。拥有这两个方法(接口)是DeZero数据集的要求。只要固定了接口,我们就可以切换使用各种数据集。

getitem 是一个特殊的 Python 方法,它定义了通过方括号访问元素(如 x[0] 和 x[1] 等)时的操作。Dataset 类的 getitem 方法仅用来取出指定索引处的数据。如果没有标签数据,它将返回输入数据 self.data[index] 和标签数据 None(这是无监督学习的情况)。另外,len 方法在使用 len 函数时被调用(如 len(x)),它用于查看数据集的长度。

除了int类型,__*_方法原本还支持切片参数。切片是指x[1:3]这样的操作。不过DeZero的Dataset类不支持切片操作,只支持int类型的索引。

以上就是Dataset类的代码。下面扩展Dataset类以实现螺旋数据集类。代码如下所示,类名为Spiral。

dezero/datasets.py

class Spiral(Dataset): def prepare(self): self.data, self.label = get_spiral(self.train)

上面代码的prepare方法仅仅将数据设置为实例变量data和label。现在我们可以像下面这样使用Spiral类来取出数据了,还可以得到数据的长度。

importdezero   
train_set  $\equiv$ dezero.datasets.Spiral(train=True)   
print(train_set[0])   
print(len(train_set))

运行结果

(array([-0.13981389, -0.00721657], dtype=float32), 1) 300

上面的代码通过train_set[0]访问数据,第0个输入数据和标签会以元组的形式返回。