46.2_SGD类的实现

46.2 SGD类的实现

现在来实现使用梯度下降法更新参数的类。下面是具体代码。

dezero/optimizers.py

class SGD(Optimizer):
    def __init__(self, lr=0.01):
        super().__init__()
        self.lr = lr
    def update_one(self, param):
        param.data -= self.lr * param_grad.data

SGD类继承于Optimizer类,初始化方法__init__接收学习率。之后的update_one方法中实现了更新参数的代码。这样就可以把参数更新交给SGD类来做了。另外,SGD类的代码实现在dezero/optimizers.py中。我们可通过fromdezero.optimizers import SGD从外部文件导入SGD。

SGD是Stochastic Gradient Descent的缩写,即随机梯度下降法。这里的随机(Stochastic)是指从对象数据中随机选择数据,并对所选数据应用梯度下降法。在深度学习领域,这种从原始数据中随机选择数据,并对这些数据应用梯度下降法的做法很常见。