46.2_SGD类的实现
46.2 SGD类的实现
现在来实现使用梯度下降法更新参数的类。下面是具体代码。
dezero/optimizers.py
class SGD(Optimizer):
def __init__(self, lr=0.01):
super().__init__()
self.lr = lr
def update_one(self, param):
param.data -= self.lr * param_grad.dataSGD类继承于Optimizer类,初始化方法__init__接收学习率。之后的update_one方法中实现了更新参数的代码。这样就可以把参数更新交给SGD类来做了。另外,SGD类的代码实现在dezero/optimizers.py中。我们可通过fromdezero.optimizers import SGD从外部文件导入SGD。

SGD是Stochastic Gradient Descent的缩写,即随机梯度下降法。这里的随机(Stochastic)是指从对象数据中随机选择数据,并对所选数据应用梯度下降法。在深度学习领域,这种从原始数据中随机选择数据,并对这些数据应用梯度下降法的做法很常见。