16.3_Variable 类的 backward
16.3 Variable类的backward
言归正传,我们来看一下Variable类的backward方法是如何实现的。重点看修改的部分(阴影部分)。
steps/step16.py
class Variable:
def backward(self): if self.grad is None: self.grad \(=\) np.ones_like(self.data) {_ seen_set \(=\) set() def add_func(f): if f not in seen_set: {_funcs.append(f) {_seen_set.add(f) {_funcs.sort(key=lambda x:x_generation) add_func(self creator) while {_ \(\mathsf{f} =\) {_funcs.pop(\\(gys \(=\) [output.grad for output in f.outputs] \\)\(gxs = f.\) backward(*gys) if not isinstance(gxs,tuple): \)\mathrm{gxs} = (\mathrm{gxs},\mathrm{)}$for x, gx in zip(f-inputs, gxs):
if x.grad is None:
x.grad = gx
else:
x.grad = x.grad + gx
if x creator is not None:
add_func(x creator)上面的代码添加了add_func函数。此前向列表中添加DeZero函数时调用的是funcs.append(f),这里改为调用add_func函数。在这个add_func函数中,DeZero函数的列表将按照generation的值排序。这样一来,在之后取出DeZero函数时,我们就可以使用funcs.pop()取出generation的值最大的函数了。
顺带一提,上面的代码在backward方法中定义了add_func函数。这种用法适用于满足以下两个条件的情况。
只在父方法(backward方法)中使用
需要访问父方法 (backward 方法) 中使用的变量 (funcs、seen_set)
由于add_func函数满足这两个条件,所以我们将它定义在了方法中。

上面的实现使用了一个名为seen_set的集合(set)。该集合的作用是防止同一个函数被多次添加到funcs列表中,由此可以防止一个函数的backward方法被错误地多次调用。