48.1_螺旋数据集
48.1 螺旋数据集
DeZero 有一个模块 (文件) 叫dezero/datasets.py。该模块包含与数据集有关的类和函数,它还内置了一些典型的机器学习的数据集。这里使用函数 get Spiral 加载其中的螺旋数据集。下面是一个简单的使用示例。
importdezero
x,t $=$ dezero.datasets.get Spiral(train $\equiv$ True)
print(x.shape)
print(t.shape)
print(x[10],t[10])
print(x[110],t[110])运行结果
(300,2)
(300,)
[0.05984409 0.0801167] 0
[-0.08959206 -0.04442143] 1get Spiral函数从参数获取标志位train。如果train=True,则返回训练数据;如果train=False,则返回测试数据。实际返回的值是x和t,x是输入数据,t是训练数据(标签)。这里的x是形状为(300,2)的ndarray实例,t是形状为(300,)的ndarray实例。这里处理的是3类分类问题,t的元素值是0、1、2其中之一。图48-1展示了呈螺旋状分布的数据集。

图48-1 呈螺旋状分布的数据集
图48-1使用了 、 和 这3种不同的符号来绘制每个类别的数据点。从图中可以看出这是一个呈螺旋状分布的数据集。下面我们使用神经网络来看看能否对这些数据正确地进行分类。