47.1_用于切片操作的函数
47.1 用于切片操作的函数
首先要增加一个工具函数,函数的名字是get_item。本节只展示该函数的使用方法,对具体实现感兴趣的读者可以参考附录B。下面是get_item函数的使用示例。
import numpy as np
fromdezero import Variable
importdezero-functionsasF
$\mathbf{x} =$ Variable(np.array([[1,2,3],[4,5,6]]))
y $=$ F.get_item(x,1)
print(y)运行结果
variable([456])上面的代码使用get_item函数从Variable的多维数组中提取出一部分
元素。这里从形状为(2,3)的 中提取了第一行的元素。这个函数被实现为DeZero函数,这意味着它的反向传播也能正确进行。我们可以试着紧跟上面的代码写出如下代码。
y.backup() print(x.grad)运行结果
variable([[0.0.0.][1.1.1.]])上面的代码调用y.backup()进行反向传播(此时通过y.grad = Variable(np.ones_like(y.data))自动补充梯度)。切片所做的计算是将多维数组中的一些数据原封不动地传递出去。因此,这个反向传播为多维数组中被提取的部分设置梯度,并将其余部分设置为0。图47-1展示了这个过程。

图47-1 get_item函数的正向传播和反向传播的示例

提取多维数组部分元素的操作叫作切片(slice)。在Python中,我们可以通过编写 或 这样的代码对列表或元组执行切片操作。
我们也可以使用get_item函数多次提取同一组元素,代码如下所示。
$\mathbf{x} =$ Variable(np.array([[1,2,3],[4,5,6]])) indices $=$ np.array([0,0,1])
y $=$ F.get_item(x,indices)
print(y)运行结果
variable([[1 2 3] [1 2 3] [4 5 6])以上就是对DeZero的get_item函数的介绍。接下来进行设置,使get_item函数也可以作为Variable的方法使用。代码如下所示。
Variable._*_item_ = F.get_item
y = x[1]
print(y)
y = x[:,2]
print(y)运行结果
variable([456])
variable([36])上面用于设置的代码是Variable.getitem = get_item。用x[1]或x[:,2]等写法编写的代码在运行时会调用get_item函数,该切片操作的反向传播也能正确进行。这个特殊方法的设置在dezero/core.py的setup_variable函数中被调用(setup_variable函数是DeZero在初始化时调用的函数)。这样就能对Variable实例自由地进行切片操作了。下面开始本步骤的主要内容。