40.2 DeZero's broadcast_to and sum_to Functions
DeZero's BroadcastTo class and broadcast_to function are shown below.
dezero/functions.py
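import numpy as np

class BroadcastTo(Function):
    def __init__(self, shape):
        self.shape = shape  # target shape after broadcasting

    def forward(self, x):
        self.x_shape = x.shape  # remember the input shape for backward
        y = np.broadcast_to(x, self.shape)
        return y

    def backward(self, gy):
        # sum the gradient back down to the input shape
        gx = sum_to(gy, self.x_shape)
        return gx

def broadcast_to(x, shape):
    if x.shape == shape:
        return as_variable(x)  # nothing to broadcast
    return BroadcastTo(shape)(x)

Let us focus on the backpropagation code. Backpropagation uses DeZero's sum_to function to sum the gradient so that its shape matches the shape of the input x. Next we implement this sum_to function. The code for the SumTo class and the sum_to function is shown below.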
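Before turning to that code, the reason a sum appears in the backward pass of a broadcast can be checked with plain NumPy. The following is a minimal sketch, not DeZero code: the arrays and the names x, y, gy and gx are chosen here only for illustration. Broadcasting copies each element of x into several outputs, so the gradient for that element is the sum of the gradients of its copies.

```python
import numpy as np

# Input of shape (1, 3), broadcast to shape (2, 3):
# every element of x is copied into two rows.
x = np.array([[1.0, 2.0, 3.0]])
y = np.broadcast_to(x, (2, 3))

# An illustrative upstream gradient with the same shape as y.
gy = np.array([[1.0, 2.0, 3.0],
               [10.0, 20.0, 30.0]])

# Each x element contributed to two outputs, so its gradient is the
# sum over the axis that was expanded. keepdims=True keeps shape (1, 3).
gx = gy.sum(axis=0, keepdims=True)

print(gx.shape)  # (1, 3), the same as x.shape
print(gx)
```

Here the summation axis is written out by hand. A general sum_to has to work out from the two shapes which axes to sum over.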
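dezero/functions.py
from dezero import utils

class SumTo(Function):
    def __init__(self, shape):
        self.shape = shape  # target shape after summation

    def forward(self, x):
        self.x_shape = x.shape  # remember the input shape for backward
        y = utils.sum_to(x, self.shape)
        return y

    def backward(self, gy):
        # copy the gradient back up to the input shape
        gx = broadcast_to(gy, self.x_shape)
        return gx

def sum_to(x, shape):
    if x.shape == shape:
        return as_variable(x)  # nothing to sum
    return SumTo(shape)(x)

Note the backpropagation code. Backpropagation copies the elements of the gradient so that its shape matches the shape of the input x. This step uses the DeZero broadcast_to function implemented earlier. As the code shows, broadcast_to and sum_to depend on each other: the backward pass of each one calls the other. With that, DeZero's broadcast_to and sum_to functions are complete.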
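The forward pass of SumTo calls utils.sum_to, a NumPy-level helper that this section does not list. The sketch below shows one way such a helper could be written. It assumes NumPy broadcasting conventions, and the name sum_to_sketch is hypothetical, so it should not be read as the actual contents of dezero/utils.py.

```python
import numpy as np

def sum_to_sketch(x, shape):
    """Sum x over axes so that the result has the given shape.

    Hypothetical helper for illustration. It assumes that `shape`
    could be broadcast to `x.shape`.
    """
    ndim = len(shape)
    lead = x.ndim - ndim                 # number of extra leading axes in x
    lead_axis = tuple(range(lead))
    # Axes whose target size is 1 must be summed as well.
    axis = tuple(i + lead for i, s in enumerate(shape) if s == 1)
    y = x.sum(lead_axis + axis, keepdims=True)
    if lead > 0:
        y = y.squeeze(lead_axis)         # drop the extra leading axes
    return y

x = np.arange(6.0).reshape(2, 3)         # [[0, 1, 2], [3, 4, 5]]
y = sum_to_sketch(x, (1, 3))             # sum over axis 0, keep 2 dims
z = sum_to_sketch(x, (3,))               # sum over axis 0, drop that axis

# The backward direction copies a gradient back to the input shape.
gx = np.broadcast_to(np.ones((1, 3)), x.shape)
print(y.shape, z.shape, gx.shape)
```

Summing with keepdims=True first and squeezing afterwards keeps the axis bookkeeping simple: only the leading axes that the target shape lacks are removed, while axes of size 1 in the target shape stay in place.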