13.2_修改Variable类

13.2 修改Variable类

现在来看看Variable类的backward方法。复习一下当前Variable类的代码,具体如下所示。

steps/step12.py

class Variable: def backward(self):
if self.grad is None:  
    self.grad = np.ones_like(self.data)  
funcs = [self creator]  
whilefuncs:  
    f =funcs.pop()  
    x, y = f-input, f.output # ①获取函数的输入和输出  
    x.grad = f_backward(y.grad) # ②调用backward方法  
    if x creator is not None:  
       funcs.append(x creator)

这里需要注意的是阴影部分的代码。首先,while循环中①的部分用于获取函数的输入输出变量。②的部分用于调用函数的backward方法。目前,①的代码只支持函数的输入输出变量只有一个的情况。下面我们对代码进行修改,使其能够支持多个变量。修改后的代码如下所示。

steps/step13.py

class Variable:   
def backward(self): if self.grad is None: self.grad  $=$  np.ones_like(self.data) funcs  $=$  [self creator] while funcs: f  $=$  funcs.pop() gys  $=$  [output.grad for output in f.outputs] # ① gxs  $=$  f.backup(\*gys) # ② if not isinstance(gxs,tuple): # ③ gxs  $=$  (gxs,) for x,gx in zip(finputs,gxs): # ④ x.grad  $=$  gx if x creator is not None: funcs.append(x creator)

这里共修改了4处。①处将输出变量outputs的导数汇总在列表中。②

处调用了函数f的反向传播。这里调用了f.backup(*gys)这种参数前带有星号的函数对列表进行解包(展开)。③处所做的处理是,当gxs不是元组时,将其转换为元组。

代码中的②和③处与上一个步骤改进正向传播的做法相同。②的代码在调用Add类的backward方法时,将参数解包后传递。③的代码使得Add类的backward方法可以简单地返回元素而不是元组。

代码中的④处将反向传播中传播的导数设置为Variable的实例变量grad。这里,gxs和finputs的每个元素都是一一对应的。准确来说,如果有第i个元素,那么f-input[i]的导数值对应于gxs[i]。于是代码中使用zip函数和for循环来设置每一对的导数。以上就是Variable类的新backward方法。