38.1_reshape函数的实现

38.1 reshape函数的实现

现在来实现变换张量形状的函数。在此之前,我们先确认一下NumPy的reshape函数的用法。编写np.reshape(x, shape)这样的代码,可以将x转换为shape的形状。下面是使用示例。

import numpy as np  
x = np.array([[1, 2, 3], [4, 5, 6]])  
y = np.reshape(x, (6))  
print(y)

运行结果

[1,2,3,4,5,6]

上面的代码将 xx 的形状由 (2, 3) 变换为 (6,)。张量中的元素数量没有改变,只有形状发生了变化。现在来实现 DeZero 版本的 reshape 函数。这里的问题是如何实现它的反向传播。

针对不会逐元素进行计算的函数,以张量的形状作为切入点,会使反向传播的实现变得清晰。具体来说,就是要确保变量的数据与梯度的形状一致。假设有Variable实例x,这时反向传播的实现需要确保x.data.shape == x.grad.shape。

reshape函数只是对形状进行变换,也就是说,它不进行具体的计算。因此在反向传播的过程中,reshape函数对从输出端传来的梯度不进行任何修改,直接将其传给输入端。不过,如图38-1所示,梯度的形状会变得与输入的形状相同。


图38-1 reshape函数的正向传播和反向传播的计算图(执行反向传播的函数用reshape' 表示,并使用伪梯度(a,b,c,d,e,f))

在图38-1中,反向传播从输出端传播梯度。为了使x.data.shape和x.grad.shape相等,我们对梯度进行转换。具体来说,就是将形状为(6,)的梯度的形状转换为(2,3)的形状,也就是将它转换成输入变量的形状。这就是reshape函数的反向传播。根据上述内容,我们来实现DeZero的reshape函数。

dezero/functions.py

class Reshape(Function): def __init__(self, shape): self.shape = shape def forward(self, x): self.x_shape = x.shape y = x.reshape(self.shape) return y def backward(self, gy): return reshape(gy, self.x_shape)

首先,在初始化reshape类的过程中,reshape类的初始化方法__init__会接收要转换的形状,并将其保存为shape。然后,正向传播的forward方法使用NumPy的reshape函数对形状进行转换。该方法使用self.x_shape = x.shape保存输入x的形状。于是,在反向传播的backward方法中,梯度的形状会转换为输入的形状(self.x_shape)。

backward(gy)的参数gy是Variable实例。因此,backward(gy)必须使用DeZero函数对Variable实例进行计算。这里用到了正在实现的reshape函数。

接下来,按如下方式实现reshape函数。

dezero/functions.py

fromdezero.core import as_variable defreshape(x,shape):
if x.shape == shape: return as_variable(x) return Reshape(shape)(x)

函数的参数x应为ndarray实例或Variable实例。如果x.shape == shape,则函数直接返回x。不过,为了确保reshape函数返回Variable实例,这里使用as_variable函数将x转换为Variable实例。另外,as_variable函数已经在步骤21中实现了。如果x为ndarray实例,as_variable(x)会将x转换为Variable实例并返回;如果x是Variable实例,as_variable(x)会直接返回x。

DeZero函数的输入是Variable实例或ndarray实例,输出是Variable实例。如果函数继承自Function类(如Reshape),ndarray实例会在该函数类的__call__方法中自动转换为Variable实例。

下面使用一下刚刚实现的reshape函数。

steps/step38.py

import numpy as np   
fromdezero import Variable   
importdezero-functionsasF   
 $\mathbf{x} =$  Variable(np.array([[1,2,3],[4,5,6]]))   
y  $=$  F.reshape(x,(6,))   
y.backup(retain_grad=True)   
print(x.grad)

运行结果

variable([[1 1 1] [1 1 1]])

上面的代码使用reshape函数来改变形状,然后调用y.backup(ret_grad=True)来求x的梯度。此时,y的梯度被自动补全。补全的梯度具有与y相同的形状(y.grid.shape == y.shape),是所有元素都为1的张量。下面我们来看一下都有哪些数据在流转。结果如图38-2所示。


图38-2 使用reshape函数进行计算的例子

如图38-2所示,在正向传播的过程中,张量的形状由(2,3)变为(6,);在反向传播的过程中,梯度的形状由(6,)变为(2,3),与正向传播的转换相反。此时可知各变量的data和grad的形状是相同的。以上就是DeZero的reshape函数的实现。下一节我们将研究如何使这个函数更加易用。