48.1_螺旋数据集

48.1 螺旋数据集

DeZero 有一个模块 (文件) 叫dezero/datasets.py。该模块包含与数据集有关的类和函数,它还内置了一些典型的机器学习的数据集。这里使用函数 get Spiral 加载其中的螺旋数据集。下面是一个简单的使用示例。

importdezero   
x,t  $=$ dezero.datasets.get Spiral(train  $\equiv$  True)   
print(x.shape)   
print(t.shape)   
print(x[10],t[10])   
print(x[110],t[110])

运行结果

(300,2)  
(300,)  
[0.05984409 0.0801167] 0  
[-0.08959206 -0.04442143] 1

get Spiral函数从参数获取标志位train。如果train=True,则返回训练数据;如果train=False,则返回测试数据。实际返回的值是x和t,x是输入数据,t是训练数据(标签)。这里的x是形状为(300,2)的ndarray实例,t是形状为(300,)的ndarray实例。这里处理的是3类分类问题,t的元素值是0、1、2其中之一。图48-1展示了呈螺旋状分布的数据集。


图48-1 呈螺旋状分布的数据集

图48-1使用了 \bigcircΔ\Delta×\times 这3种不同的符号来绘制每个类别的数据点。从图中可以看出这是一个呈螺旋状分布的数据集。下面我们使用神经网络来看看能否对这些数据正确地进行分类。