9.3_只支持ndarray

9.3 只支持ndarray

DeZero的Variable只支持ndarray实例的数据。但是,有些用户很可能会不小心使用float或int等数据类型,例如Variable(1.0)和Variable(3)等。考虑到这一点,我们再做一点优化,使Variable成为只能容纳ndarray实例的“箱子”。具体来说,就是当把ndarray实例以外的数据放入Variable时,让DeZero立即抛出错误(None是允许放入的)。这项改进有望让用户在早期发现问题。下面我们在Variable类的初始化部分添加以下代码。

steps/step09.py

class Variable: def__init__(self,data): ifdataisnotNone: ifnotisinstance(d.data,np.ndarray):} raiseTypeError('{}isnotsupported'.format(type(data))) self.data  $=$  data self.grad  $=$  None self creator  $=$  None

在上面的代码中,如果作为参数的 data 不是 None,也不是 ndarray 实例,就会引发 TypeError 异常。这时,程序会输出代码中指定的字符串作为错误提示。现在,可以像下面这样使用 Variable。

steps/step09.py

$\mathbf{x} =$  Variable(np.array(1.0)) # OK  
 $\mathbf{x} =$  Variable(None) # OK  
 $\mathbf{x} =$  Variable(1.0) # NG: 错误发生!

运行结果

TypeError: <class 'float'> is not supported

在上面的代码中,如果数据为ndarray或None,就可以顺利创建Variable。但如果是其他的数据类型,比如上面代码中的float,DeZero就会抛出一个异常。这样一来,用户就能立刻知道自己使用了错误的数据类型。

这项改进也带来一个问题,这个问题是由NumPy自身的特点所导致的。在解释这个问题之前,我们先看看下面的NumPy代码。

$\mathbf{x} = \mathbf{np}$  .array([1.0])   
y=x  $^{**}2$    
print(type(x),x.ndim)   
print(type(y))

运行结果

<class 'numpy.ndarray'> 1  
<class 'numpy.ndarray'>

代码中的x是一维的ndarray。x ** 2(平方)的结果y的数据类型是ndarray。这是预期的结果,问题出现在下面这种情况。

$\mathbf{x} = \mathbf{np}$  .array(1.0)   
y=x  $^{**}2$    
print(type(x),x.ndim)   
print(type(y))

运行结果

<class 'numpy.ndarray'> 0  
<class 'numpy.float64'>

上面代码中的 xx 是零维的ndarray,而 x×2x \times 2 的结果是np.float64,这是NumPy的运行方式①。换言之,如果用零维的ndarray实例进行计算,结果将是ndarray实例以外的数据类型,如numpy.float64、numpy.float32等。这意味着DeZero函数的输出Variable可能是numpy.float64或numpy.float32类型的数据。但是,Variable中的数据只允许保存ndarray实例。为了解决这个问题,我们首先准备以下函数作为工具函数。

steps/step09.py
def as_array(x):
    if np.isscalar(x):
        return np.array(x)
    return x

上面的代码使用np.isscalar函数来检查numpy.float64等属于标量的类型(它也可以用来检查Python的int和float)。下面是使用np.isscalar函数的示例代码。

>>> import numpy as np
>>> np.isscalar(np.float64(1.0))
True
>>> np.isscalar(2.0)
True
>>> np.isscalar(np.array(1.0))
False
>>> np.isscalar(np.array([1, 2, 3]))
False

从这些例子可以看出,通过np.isscalar(x)可以判断x是否为ndarray实例。如果不是,则使用as_array函数将其转换为ndarray实例。实现as_array这个工具函数之后,在Function类中添加以下阴影部分的代码。

steps/step09.py
class Function: def__call__(self, input): x  $=$  input.data y  $=$  self.forward(x) output  $=$  Variable(as_array(y)) output.set creator(self) self-input  $\equiv$  input self.output  $=$  output return output

上面的代码在将正向传播的结果y封装在Variable中时使用了as_array(y),这样可以确保输出结果output是ndarray实例的数据。即使使用零维的ndarray实例进行计算,所有的数据也会是ndarray实例。

至此,本步骤的工作就完成了。下一个步骤的主题是测试DeZero。