13.2_修改Variable类
13.2 修改Variable类
现在来看看Variable类的backward方法。复习一下当前Variable类的代码,具体如下所示。
steps/step12.py
class Variable: def backward(self):if self.grad is None:
self.grad = np.ones_like(self.data)
funcs = [self creator]
whilefuncs:
f =funcs.pop()
x, y = f-input, f.output # ①获取函数的输入和输出
x.grad = f_backward(y.grad) # ②调用backward方法
if x creator is not None:
funcs.append(x creator)这里需要注意的是阴影部分的代码。首先,while循环中①的部分用于获取函数的输入输出变量。②的部分用于调用函数的backward方法。目前,①的代码只支持函数的输入输出变量只有一个的情况。下面我们对代码进行修改,使其能够支持多个变量。修改后的代码如下所示。
steps/step13.py
class Variable:
def backward(self): if self.grad is None: self.grad $=$ np.ones_like(self.data) funcs $=$ [self creator] while funcs: f $=$ funcs.pop() gys $=$ [output.grad for output in f.outputs] # ① gxs $=$ f.backup(\*gys) # ② if not isinstance(gxs,tuple): # ③ gxs $=$ (gxs,) for x,gx in zip(finputs,gxs): # ④ x.grad $=$ gx if x creator is not None: funcs.append(x creator)这里共修改了4处。①处将输出变量outputs的导数汇总在列表中。②
处调用了函数f的反向传播。这里调用了f.backup(*gys)这种参数前带有星号的函数对列表进行解包(展开)。③处所做的处理是,当gxs不是元组时,将其转换为元组。

代码中的②和③处与上一个步骤改进正向传播的做法相同。②的代码在调用Add类的backward方法时,将参数解包后传递。③的代码使得Add类的backward方法可以简单地返回元素而不是元组。
代码中的④处将反向传播中传播的导数设置为Variable的实例变量grad。这里,gxs和finputs的每个元素都是一一对应的。准确来说,如果有第i个元素,那么f-input[i]的导数值对应于gxs[i]。于是代码中使用zip函数和for循环来设置每一对的导数。以上就是Variable类的新backward方法。