18.2_回顾Function类
18.2 回顾 Function 类
在DeZero中计算导数时,要先进行正向传播再进行反向传播。反向传播阶段需要正向传播阶段的计算结果,所以我们需要保存(记住)这些结果。下面的Function类的阴影区域就是用来实际保存计算结果的代码。
steps/step18.py
class Function: def__call__(self,\*inputs): xs $=$ [x.data for x in inputs] ys $=$ self.forward(\*xs) if not isinstance(ys,tuple): ys $=$ (ys,) outputs $=$ [Variable(as_array(y)) for y in ys] self_generation $=$ max([x_generation for x in inputs]) for output in outputs: output.set creator(self) selfInputs $\equiv$ inputs self.outputs $=$ [weakref.ref(output) for output in outputs] return outputs if lenoutputs) $>1$ else outputs[0]在上面的代码中,函数的输入被一个名为inputs的实例变量引用。inputs引用的变量,其引用计数增加了1。这意味着在调用__call__方法后,inputs引用的变量将继续保留在内存中。如果此时不再引用inputs,那么引用计数将变为0,inputs将被从内存中删除。
实例变量inputs用于反向传播的计算。因此,在进行反向传播时,要保留inputs所引用的变量。不过有些时候并不需要求导,在这种情况下,我们没有必要保留中间计算的结果,也没有必要在计算之间创建“连接”。

神经网络分为训练(train)和推理(inference)两个阶段。在训练阶段需要求出导数,在推理阶段只进行正向传播。在只进行正向传播时,我们可以把中间的计算结果“扔掉”,这将大幅缩减内存的使用量。