22.4_幂运算

22.4 幂运算

幂运算用式子 y=xcy = x^c 表示,其中 xx 称为底, cc 称为指数。由导数公式可知,幂的导数为 yx=cxc1\frac{\partial y}{\partial x} = cx^{c - 1} 。至于 yc\frac{\partial y}{\partial c} ,在实践中需要计算它的情况并不多(当然也可以计算它),因此本书只考虑计算底 xx 的导数的情况。也就是说,我们将指数 cc 视为常数,不去计算它的导数。考虑到这一点,可将代码编写如下。

steps/step22.py

class Pow(Function): def __init__(self, c): self.c = c def forward(self, x): y = x ** self.c return y def backward(self, gy): x = self.inputs[0].data c = self.c gx = c * x ** (c - 1) * gy return gx def pow(x, c): return Pow(c)(x) Variable._pow_ = pow

上面的代码在 Pow 类初始化时设置指数 c。正向传播 forward(x) 只接受一个变量,即底 x。最后一行代码的意思是将函数 pow 赋给特殊方法 pow。这样,我们就可以使用 **运算符来进行幂运算了。下面是以上代码的应用示例。

steps/step22.py

$\begin{array}{rl} & {\mathrm{x} = \mathrm{Variable}(\mathrm{np.array}(2.\theta))}\\ & {\mathrm{y} = \mathrm{x}^{**}3}\\ & {\mathrm{print}(\mathrm{y})} \end{array}$

运行结果

variable(8.0)

到这里就完成了添加运算符的工作。本步骤虽然有些枯燥,但 DeZero 的可用性得到了大幅提升。现在我们可以用各种运算符自由地进行四则运算了,甚至还能进行幂运算,能实现的计算也越来越复杂。在下一个步骤,我们会把现有的成果整理成一个 Python 的包,然后验证当前的 DeZero 所具备的能力。