18.1_不保留不必要的导数
18.1 不保留不必要的导数
第1项改进针对的是DeZero的反向传播。目前在DeZero中,所有的变量都保留了导数,比如下面这个例子。
$\begin{array}{l}\mathrm{x0} = \mathrm{Variable(np.array(1.0))}\\ \mathrm{x1} = \mathrm{Variable(np.array(1.0))}\\ \mathrm{t} = \mathrm{add(x0,x1)}\\ \mathrm{y} = \mathrm{add(x0,t)}\\ \mathrm{y的背后()} \end{array}$
print(y_grad,t_grad)
print(x0_grad,x1_grad)运行结果
1.0 1.0
2.0 1.0在上面的代码中,用户的提供的变量是 和 。变量 和 是通过计算产生的。当用 .backward() 计算导数时,所有变量都会保留它们的导数。不过在多数情况下,尤其在机器学习中,只有终端变量 的导数才需要通过反向传播求得。在上面的例子中, 或 等中间变量的导数基本用不到。因此,我们可以增加一种消除这些中间变量的导数的模式。为此,我们要在当前 Variable 类的 backward 方法中添加以下阴影部分的代码。
steps/step18.py
class Variable:
def backward(self, retain_grad=False): if self.grad is None: self.grad = np.ones_like(self.data)
funcs = [] seen_set = set()
def add_func(f): if f not in seen_set:funcs.append(f) seen_set.add(f)funcs.sort(key=lambda x:x-generation)
add_func(self creator)
whilefuncs: f =funcs.pop() gys = [output().grad for output in f.outputs] gxs = f.backup(*gys) if not isinstance(gxs,tuple): gxs = (gxs,) for x,gx in zip(finputs,gxs): if x.grad is None: x.grad = gx else: x.grad = x.grad + gx if x creator is not None: add_func(x creator)
if not retain_grad: for y in f.outputs: y().grad = None #y是weakref上面的代码首先添加 retain_grad 作为方法的参数。如果 retain_grad 为 True,那么所有的变量都会像之前一样保留它们的导数(梯度)。如果 retain_grad 为 False(默认为 False),那么所有中间变量的导数都会被重置。其原理就在于 backward 方法的 for 语句末尾的 y().grad = None,这行代码的意思是不要保留各函数输出变量的导数。这样一来,除终端变量外,其他变量的导数都不会被保留。

之所以写成y().grad = None,是因为y是弱引用,必须以y()的形式访问(上一个步骤已经引入了弱引用机制)。另外,设置y().grad = None之后,引用计数将变为0,导数的数据会从内存中被删除。
再次运行前面的测试代码。
steps/step18.py
$\mathbf{x}\theta =$ Variable(np.array(1.0))
x1 $=$ Variable(np.array(1.0))
t $=$ add(x0,x1)
y $=$ add(x0,t)
y.backup()
print(y.grad,t.grad)
print(x0.grad,x1.grad)运行结果
None None 2.01.0在上面的代码中,中间变量y和t的导数已经被删除。占用的内存空间被立即释放了出来。这样就完成了DeZero在内存使用上的第1项改进。下面进行第2项改进。在正式操作之前,我们先来回顾一下当前Function类的代码。