13.3_Square类的实现

13.3 Square类的实现

现在Variable类和Function类已经支持可变长的输入和输出了。我们还实现了一个Add类作为具体的函数。最后,我们需要改进目前使用的Square类,使其支持新的Variable类和Function类。要修改的地方只有一处(阴影部分)。

steps/step13.py

class Square(Function): def forward(self, x): y = x ** 2 return y def backward(self, gy): x = self.inputs[0].data #修改前为x = self/input.data gx = 2 * x * gy return gx

如上面的代码所示,由于Function类的实例变量已经从input(单数形式)变为inputs(复数形式),所以Square需修改为从inputs中取出输入变量x。

这样新的Square类就完成了。下面使用add函数和square函数实际进行计算。

steps/step13.py

$\begin{array}{rl} & {\mathrm{x} = \mathrm{Variable}(\mathrm{np.array}(2.0))}\\ & {\mathrm{y} = \mathrm{Variable}(\mathrm{np.array}(3.0))}\\ & {\mathrm{z} = \mathrm{add}(\mathrm{square(x)},\mathrm{square(y)})}\\ & {\mathrm{z/backward()}}\\ & {\mathrm{print(z.data)}}\\ & {\mathrm{print(x.grad)}}\\ & {\mathrm{print(y.grad)}} \end{array}$

运行结果

13.0  
4.0  
6.0

上面的代码计算了 z=x2+y2z = x^{2} + y^{2} 。在使用DeZero的情况下,这个计算可以写成 z=add(square(x),square(y)z = \text{add(square(x)}, \text{square(y)} 的形式。之后只要调用z.reverse()就能自动求出导数了。

通过以上修改,我们实现了支持多个输入和输出的自动微分的机制。后面只要按部就班地编写需要的函数,就可以实现更复杂的计算。不过,当前的DeZero还存在一个问题。在下一个步骤,我们将解决这个问题。

步骤14

重复使用同一个变量

当前的DeZero有一个问题,每当重复使用同一个变量,这个问题就会出现。例如,图14-1所示的 y=add(x,x)y = \operatorname{add}(x, x) 的情况。


图14-1 y=add(x,x)y = \operatorname{add}(x, x) 的计算图

DeZero在用相同变量进行加法运算时不能正确求导。下面测试一下,看看实际的结果是什么样的。

$\mathbf{x} =$  Variable(np.array(3.0))   
y  $=$  add(x,x)   
print('y',y.data)   
y.backup()   
print('x.grad',x.grad)

运行结果

y 6.0  
x. grad 1.0

上面的代码以 x=3.0x = 3.0 进行了加法运算。在这种情况下,y的值是6.0,这是正确的结果。但是,x的导数(x.grad)是1.0,这是错误的结果。当

y=x+xy = x + x 时, y=2xy = 2x ,所以导数的正确结果为 yx=2\frac{\partial y}{\partial x} = 2