32.2_函数类的反向传播
32.2 函数类的反向传播
剩下的工作就是修改backward方法(不修改Function类)。之前,我们已经在dezero/core/simple.py文件中实现了以下DeZero的函数类。
Add
Mul
Neg
Sub
DivPow
我们将修改这些类的 backward 方法, 然后将它们添加到dezero/core.py中。首先从Add类开始, 不过Add类不需要做任何修改。这里我们看一下Add类的实现, 代码如下所示。
dezero/core.py
class Add(Function): def forward(self, x0, x1): y = x0 + x1 return y def backward(self, gy): return gy, gyAdd类的反向传播只是将导数从输出端向输入端传递而已。反向传播中没有计算任何内容,所以没有需要修改的代码。
下一个是Mul类。将Mul类的backward方法按图32-1进行修改。
class Mul(Function): def backward(self, gy): x0 = self.inputs[0].data x1 = selfinputs[1].data return gy * x1, gy * x0class Mul Function): def backward(self, gy): x0, x1 = self.inputs return gy \*x1, gy \*x0如图32-1所示,之前我们需要从Variable实例中取出数据(ndarray实例),而在新的Mul类中,Variable实例可以直接使用。
图32-1中需要大家注意的是进行反向传播的代码gy * x1。再次强调,在新的DeZero中,gy和x1是Variable实例。我们已经在Variable类上实现了*运算符的重载,因此在执行gy * x1的背后,Mul类的正向传播会被调用。此时,Function.call( )会被调用,该方法中会构建计算图。

图32-1 Mul类的backward方法的对比(左边是旧代码,右边是新代码)
反向传播的计算针对的是Variable实例,所以我们需要使用DeZero函数对Variable实例进行计算。
之后按照同样的步骤对Sub类、Div类和Pow类修改backward方法即可。修改方法与前面介绍的内容相同,本书就不再一一介绍了。
32.3 实现更有效的反向传播(增加模式控制代码)
我们在步骤18中引入了启用/禁用反向传播的模式。具体来说,当不需要反向传播时,切换到“禁用反向传播模式”,以此来省略用于反向传播的处理(如创建计算图和保存输入变量等)。这里对反向传播中进行的计算使用同样的机制。也就是说,对于在反向传播中进行的计算,如果不想再次反向
传播了,即只进行一次反向传播,就要在“禁用反向传播模式”下进行反向传播的计算。为了实现这个机制,我们需要在Variable类的backward方法中添加以下代码。
dezero/core.py
def backward(self, retain_grad=False, create_graph=False):
...
while func:
f = func.pop()
gys = [output().grad for output in f.outputs]
with using_config('enable_backward', create_graph):
gxs = f_backward(*gys) # 主要的backward处理
if not isinstance(gxs, tuple):
gxs = (gxs,)
for x, gx in zip(f-inputs, gxs):
if x.grad is None:
x.grad = gx
else:
x.grad = x.grad + gx # 这个计算也是对象
if x creator is not None:
add_func(x creator)首先增加create_graph参数,并将默认值设置为False。然后在with using_config(...)中进行实际的反向传播处理(步骤18中已经解释过using_config函数的用法,这里不再赘述)。这意味着当create_graph为False时,反向传播中的计算是在禁用反向传播模式下进行的。

这部分内容有点复杂,笔者用具体的例子来补充说明。例如,在进行Mul类的反向传播时,其backward方法执行gy * x1的计算。因为*运算符被重载,所以代码Mul()(gy, x1)会被调用,这会触发父类Function.call( )被调用。Function.call( 方法会根据Config enable_backprop的值来启用或禁用反向传播。
为什么让create_graph=False作为默认设置呢?这是因为只需要一次反向传播的情况占大多数。如果需要求二阶导数,将参数设置为True即可。在这种情况下,反向传播的计算会创建出新的计算图,反向传播得以继续进行。