20.2_运算符重载

20.2 运算符重载

下面先重载乘法运算符*。乘法的特殊方法是__mul__(self, other)(参数 self 和 other 的相关内容将在后面解释)。如果定义(实现)了__mul__方法,那么在使用*进行计算时,__mul__方法就会被调用。下面试着实现Variable类的__mul__方法,代码如下所示。

Variable: def __mul__(self, other): return mul(self, other)

上面的代码向已经实现的Variable类中添加了__mul__方法。这样在使用进行计算时,被调用的就是__mul__方法,这个方法内部又会调用mul函数。下面我们用运算符做一些计算。

a = Variable(np.array(3.0))  
b = Variable(np.array(2.0))  
y = a * b  
print(y)

运行结果

variable(6.0)

上面的代码成功实现了 y=aby = a * b 的计算。当执行 aba * b 时,实例a的__mul__(self, other)方法被调用。这时,运算符*左侧的a作为self参数、右侧的b作为other参数传给了__mul__方法,具体如图20-2所示。


图20-2 向__mul__方法传递参数的示意图

在上面的例子中,当执行 aba * b 的代码时,首先实例a的特殊方法 mul 方法会被调用。如果a中没有实现 mul 方法,那么实例b中运算符的特殊方法会被调用。在这个例子中b在运算符的右侧,所以调用的特殊方法是 rmul

这样就完成了*运算符的重载。为此,我们实现了Variable类的__mul__方法。下面的代码可以达到同样的目的。

steps/step20.py   
class Variable: Variable._mul_  $=$  mul Variable._add_  $=$  add

如上所示,在定义Variable类后,又写了Variable.mul = mul。在Python中函数也是对象,所以我们可以按上面的公式把函数赋给方法。于是,在调用Variable实例的__mul__方法时,mul函数会被调用。

另外,上面的代码还设置了运算符 +^+ 的特殊方法__add__。这样就实现了 ++ 运算符的重载。下面使用 +^+* 进行一些计算。

steps/step20.py

a = Variable(np.array(3.0))  
b = Variable(np.array(2.0))  
c = Variable(np.array(1.0))  
# y = add(mul(a, b), c)  
y = a * b + c  
y.backup()  
print(y)  
print(a.grad)  
print(b.grad)

运行结果

variable(7.0)  
2.0  
3.0

上面的代码成功实现了 y=ab+cy = a * b + c 的计算。现在可以使用 ++* 自由地进行计算了。基于同样的做法还可以实现其他运算符(如/和-等)的重载。在下一个步骤,我们将继续实现这部分的内容。

步骤21

运算符重载(2)

我们的DeZero越来越好用了。在有Variable实例a和b的情况下,我们可以写出a * b或a + b这样的代码。不过,现在还不能使用a * np.array(2.0)这种将Variable实例与ndarray实例结合起来的代码,也不能使用3 + b这种将Variable实例与数值数据结合起来的代码。如果能将Variable实例与ndarray实例和数值数据结合使用,DeZero会更加好用。本步骤将扩展Variable,使Variable实例能够与ndarray实例,以及int和float等类型的数据一起使用。

20.2_运算符重载 - 深度学习自制框架 | OpenTech