32.2_函数类的反向传播

32.2 函数类的反向传播

剩下的工作就是修改backward方法(不修改Function类)。之前,我们已经在dezero/core/simple.py文件中实现了以下DeZero的函数类。

Add
Mul

  • Neg
    Sub
    Div

  • Pow

我们将修改这些类的 backward 方法, 然后将它们添加到dezero/core.py中。首先从Add类开始, 不过Add类不需要做任何修改。这里我们看一下Add类的实现, 代码如下所示。

dezero/core.py

class Add(Function): def forward(self, x0, x1): y = x0 + x1 return y def backward(self, gy): return gy, gy

Add类的反向传播只是将导数从输出端向输入端传递而已。反向传播中没有计算任何内容,所以没有需要修改的代码。

下一个是Mul类。将Mul类的backward方法按图32-1进行修改。

class Mul(Function): def backward(self, gy): x0 = self.inputs[0].data x1 = selfinputs[1].data return gy * x1, gy * x0
class Mul Function): def backward(self, gy): x0, x1 = self.inputs return gy \*x1, gy \*x0

如图32-1所示,之前我们需要从Variable实例中取出数据(ndarray实例),而在新的Mul类中,Variable实例可以直接使用。

图32-1中需要大家注意的是进行反向传播的代码gy * x1。再次强调,在新的DeZero中,gy和x1是Variable实例。我们已经在Variable类上实现了*运算符的重载,因此在执行gy * x1的背后,Mul类的正向传播会被调用。此时,Function.call( )会被调用,该方法中会构建计算图。


图32-1 Mul类的backward方法的对比(左边是旧代码,右边是新代码)

反向传播的计算针对的是Variable实例,所以我们需要使用DeZero函数对Variable实例进行计算。

之后按照同样的步骤对Sub类、Div类和Pow类修改backward方法即可。修改方法与前面介绍的内容相同,本书就不再一一介绍了。

32.3 实现更有效的反向传播(增加模式控制代码)

我们在步骤18中引入了启用/禁用反向传播的模式。具体来说,当不需要反向传播时,切换到“禁用反向传播模式”,以此来省略用于反向传播的处理(如创建计算图和保存输入变量等)。这里对反向传播中进行的计算使用同样的机制。也就是说,对于在反向传播中进行的计算,如果不想再次反向

传播了,即只进行一次反向传播,就要在“禁用反向传播模式”下进行反向传播的计算。为了实现这个机制,我们需要在Variable类的backward方法中添加以下代码。

dezero/core.py

def backward(self, retain_grad=False, create_graph=False):
    ...
while func:
    f = func.pop()
    gys = [output().grad for output in f.outputs]
with using_config('enable_backward', create_graph):
        gxs = f_backward(*gys)  # 主要的backward处理
        if not isinstance(gxs, tuple):
            gxs = (gxs,) 
        for x, gx in zip(f-inputs, gxs):
            if x.grad is None:
                x.grad = gx
            else:
                x.grad = x.grad + gx  # 这个计算也是对象
            if x creator is not None:
                add_func(x creator)

首先增加create_graph参数,并将默认值设置为False。然后在with using_config(...)中进行实际的反向传播处理(步骤18中已经解释过using_config函数的用法,这里不再赘述)。这意味着当create_graph为False时,反向传播中的计算是在禁用反向传播模式下进行的。

这部分内容有点复杂,笔者用具体的例子来补充说明。例如,在进行Mul类的反向传播时,其backward方法执行gy * x1的计算。因为*运算符被重载,所以代码Mul()(gy, x1)会被调用,这会触发父类Function.call( )被调用。Function.call( 方法会根据Config enable_backprop的值来启用或禁用反向传播。

为什么让create_graph=False作为默认设置呢?这是因为只需要一次反向传播的情况占大多数。如果需要求二阶导数,将参数设置为True即可。在这种情况下,反向传播的计算会创建出新的计算图,反向传播得以继续进行。