20.2_运算符重载
20.2 运算符重载
下面先重载乘法运算符*。乘法的特殊方法是__mul__(self, other)(参数 self 和 other 的相关内容将在后面解释)。如果定义(实现)了__mul__方法,那么在使用*进行计算时,__mul__方法就会被调用。下面试着实现Variable类的__mul__方法,代码如下所示。
Variable: def __mul__(self, other): return mul(self, other)上面的代码向已经实现的Variable类中添加了__mul__方法。这样在使用进行计算时,被调用的就是__mul__方法,这个方法内部又会调用mul函数。下面我们用运算符做一些计算。
a = Variable(np.array(3.0))
b = Variable(np.array(2.0))
y = a * b
print(y)运行结果
variable(6.0)上面的代码成功实现了 的计算。当执行 时,实例a的__mul__(self, other)方法被调用。这时,运算符*左侧的a作为self参数、右侧的b作为other参数传给了__mul__方法,具体如图20-2所示。

图20-2 向__mul__方法传递参数的示意图

在上面的例子中,当执行 的代码时,首先实例a的特殊方法 mul 方法会被调用。如果a中没有实现 mul 方法,那么实例b中运算符的特殊方法会被调用。在这个例子中b在运算符的右侧,所以调用的特殊方法是 rmul。
这样就完成了*运算符的重载。为此,我们实现了Variable类的__mul__方法。下面的代码可以达到同样的目的。
steps/step20.py
class Variable: Variable._mul_ $=$ mul Variable._add_ $=$ add如上所示,在定义Variable类后,又写了Variable.mul = mul。在Python中函数也是对象,所以我们可以按上面的公式把函数赋给方法。于是,在调用Variable实例的__mul__方法时,mul函数会被调用。
另外,上面的代码还设置了运算符 的特殊方法__add__。这样就实现了 运算符的重载。下面使用 和 进行一些计算。
steps/step20.py
a = Variable(np.array(3.0))
b = Variable(np.array(2.0))
c = Variable(np.array(1.0))
# y = add(mul(a, b), c)
y = a * b + c
y.backup()
print(y)
print(a.grad)
print(b.grad)运行结果
variable(7.0)
2.0
3.0上面的代码成功实现了 的计算。现在可以使用 和 自由地进行计算了。基于同样的做法还可以实现其他运算符(如/和-等)的重载。在下一个步骤,我们将继续实现这部分的内容。
步骤21
运算符重载(2)
我们的DeZero越来越好用了。在有Variable实例a和b的情况下,我们可以写出a * b或a + b这样的代码。不过,现在还不能使用a * np.array(2.0)这种将Variable实例与ndarray实例结合起来的代码,也不能使用3 + b这种将Variable实例与数值数据结合起来的代码。如果能将Variable实例与ndarray实例和数值数据结合使用,DeZero会更加好用。本步骤将扩展Variable,使Variable实例能够与ndarray实例,以及int和float等类型的数据一起使用。