46.1_Optimizer 类

46.1 Optimizer类

本节把进行参数更新的基础类实现为Optimizer(优化器)类。Optimizer类是执行优化操作的基类。我们需要在继承了Optimizer的类中实现具体的优化方法。Optimizer类的实现如下所示。

dezero/optimizers.py

class Optimizer: def__init__(self): self.target  $=$  None self.hooks  $\equiv$  [] def setup(self,target): self.target  $=$  target return self defupdate(self): #将None之外的参数汇总到列表 params  $=$  [p for p in self.target.params() if p.grad is not None]
#预处理(可选)  
forf in self投机:f.params)  
#更新参数  
forparaminparams:self.updateone(param)  
defupdateone(self,param):raiseNotImplementedError()  
defadd hookself,f):self投机.append(f)

Optimizer类在初始化阶段初始化了两个实例变量,分别是target和hooks。然后,通过setup方法将作为类实例(Model或Layer)的参数变量设置为target实例变量。

Optimizer类的update方法对除grad实例变量为None的参数之外的其他参数进行了更新。此外,具体的参数更新通过update_one方法进行。update_one方法在继承Optimizer的类中通过重写来实现。

此外,Optimizer类还具有在更新参数之前对所有参数进行预处理的功能。用户可使用add hook方法来添加进行预处理的函数。这个机制可用于权重衰减、梯度裁剪(参考文献[26])等(实现示例在example/mnist.py等文件中)。