40.2 DeZero's broadcast_to and sum_to Functions

DeZero's BroadcastTo class and broadcast_to function are shown below.

dezero/functions.py

import numpy as np
from dezero.core import Function, as_variable


class BroadcastTo(Function):
    def __init__(self, shape):
        self.shape = shape

    def forward(self, x):
        self.x_shape = x.shape               # remember the input shape for backward
        y = np.broadcast_to(x, self.shape)
        return y

    def backward(self, gy):
        gx = sum_to(gy, self.x_shape)        # sum the gradient back to x's shape
        return gx


def broadcast_to(x, shape):
    if x.shape == shape:                     # nothing to broadcast
        return as_variable(x)
    return BroadcastTo(shape)(x)
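To make the backward pass concrete, here is a small NumPy-only sketch (plain arrays, not DeZero Variables) of what BroadcastTo does in each direction for one case: a vector of shape (3,) broadcast to (2, 3). Every element of x is copied into both rows of y, so its gradient is the sum of the gradients of its copies, which is the reduction that sum_to performs.

```python
import numpy as np

# Forward: broadcast a (3,) vector to (2, 3), as BroadcastTo.forward does
x = np.array([1.0, 2.0, 3.0])
y = np.broadcast_to(x, (2, 3))

# Backward: assume the upstream gradient gy is all ones with y's shape
gy = np.ones_like(y)

# Each x[j] appears in both rows of y, so its gradient is the sum of the
# gradients of its two copies: summing over axis 0 restores x's shape (3,)
gx = gy.sum(axis=0)

print(y.shape)   # (2, 3)
print(gx)        # [2. 2. 2.]
print(gx.shape)  # (3,)
```

The summed axis is exactly the one broadcasting created, which is why the gradient ends up with the input's shape.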

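Let's focus on the backward pass of BroadcastTo. It uses DeZero's sum_to function to sum the gradient so that it has the same shape as the input x. Next, we implement this sum_to function. The code for the SumTo class and the sum_to function is as follows.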

from dezero import utils


class SumTo(Function):
    def __init__(self, shape):
        self.shape = shape

    def forward(self, x):
        self.x_shape = x.shape               # remember the input shape for backward
        y = utils.sum_to(x, self.shape)
        return y

    def backward(self, gy):
        gx = broadcast_to(gy, self.x_shape)  # copy the gradient back to x's shape
        return gx


def sum_to(x, shape):
    if x.shape == shape:                     # nothing to sum
        return as_variable(x)
    return SumTo(shape)(x)
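The reduction performed by utils.sum_to is not shown in this section. Below is a minimal NumPy sketch of one way to write such a reduction; `sum_to_np` is a hypothetical name, and it is only assumed to behave like DeZero's utils.sum_to. It sums away the leading axes that broadcasting prepended and the axes where the target shape has size 1; the backward direction then simply broadcasts the gradient back.

```python
import numpy as np


def sum_to_np(x, shape):
    """Sum ndarray x down to `shape`, undoing NumPy-style broadcasting.

    Hypothetical helper, assumed (not confirmed) to match utils.sum_to.
    """
    lead = x.ndim - len(shape)               # number of axes broadcasting prepended
    lead_axis = tuple(range(lead))
    # axes where the target has size 1 but x was stretched
    axis = tuple(i + lead for i, sx in enumerate(shape) if sx == 1)
    y = x.sum(lead_axis + axis, keepdims=True)
    if lead > 0:
        y = y.squeeze(lead_axis)             # drop the prepended axes
    return y


# Forward of SumTo: (2, 3, 4) -> (3, 1)
x = np.arange(24.0).reshape(2, 3, 4)
y = sum_to_np(x, (3, 1))

# Backward of SumTo: broadcast the gradient back to x's shape
gy = np.ones((3, 1))
gx = np.broadcast_to(gy, x.shape)

print(y.shape, gx.shape)  # (3, 1) (2, 3, 4)
```

Using keepdims=True and then squeezing only the prepended axes keeps the size-1 axes of the target in place, so the result has exactly the requested shape.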

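Note the backward pass of SumTo as well. It copies the elements of the gradient so that the result has the shape of the input x, using the DeZero broadcast_to function implemented above. As the code shows, broadcast_to and sum_to depend on each other: the backward pass of each one calls the other. This completes DeZero's broadcast_to and sum_to functions.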
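The mutual dependence has a simple linear-algebra reading: sum_to is the transpose (adjoint) of broadcast_to, so the gradient of each is computed by the other. The NumPy sketch below checks the identity ⟨broadcast_to(x), g⟩ = ⟨x, sum_to(g)⟩ for random arrays; the sum_to step is written out by hand for this particular pair of shapes rather than calling DeZero.

```python
import numpy as np

rng = np.random.default_rng(0)

x_shape, y_shape = (3, 1), (2, 3, 4)
x = rng.standard_normal(x_shape)
g = rng.standard_normal(y_shape)     # an arbitrary upstream gradient

# broadcast_to: (3, 1) -> (2, 3, 4)
bx = np.broadcast_to(x, y_shape)

# sum_to for this pair of shapes: sum away the prepended axis 0, then sum
# the axis that was originally size 1 while keeping it, giving (3, 1)
sg = g.sum(axis=0).sum(axis=1, keepdims=True)

# Adjoint identity: both inner products agree, so sum_to is the transpose
# of broadcast_to and serves as its backward pass (and vice versa)
lhs = np.sum(bx * g)
rhs = np.sum(x * sg)
print(np.isclose(lhs, rhs))  # True
```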