47.1_用于切片操作的函数

47.1 用于切片操作的函数

首先要增加一个工具函数,函数的名字是get_item。本节只展示该函数的使用方法,对具体实现感兴趣的读者可以参考附录B。下面是get_item函数的使用示例。

import numpy as np   
fromdezero import Variable   
importdezero-functionsasF   
 $\mathbf{x} =$  Variable(np.array([[1,2,3],[4,5,6]]))   
y  $=$  F.get_item(x,1)   
print(y)

运行结果

variable([456])

上面的代码使用get_item函数从Variable的多维数组中提取出一部分

元素。这里从形状为(2,3)的 xx 中提取了第一行的元素。这个函数被实现为DeZero函数,这意味着它的反向传播也能正确进行。我们可以试着紧跟上面的代码写出如下代码。

y.backup() print(x.grad)

运行结果

variable([[0.0.0.][1.1.1.]])

上面的代码调用y.backup()进行反向传播(此时通过y.grad = Variable(np.ones_like(y.data))自动补充梯度)。切片所做的计算是将多维数组中的一些数据原封不动地传递出去。因此,这个反向传播为多维数组中被提取的部分设置梯度,并将其余部分设置为0。图47-1展示了这个过程。


图47-1 get_item函数的正向传播和反向传播的示例

提取多维数组部分元素的操作叫作切片(slice)。在Python中,我们可以通过编写 x[1]x[1]x[1:4]x[1:4] 这样的代码对列表或元组执行切片操作。

我们也可以使用get_item函数多次提取同一组元素,代码如下所示。

$\mathbf{x} =$  Variable(np.array([[1,2,3],[4,5,6]])) indices  $=$  np.array([0,0,1])   
y  $=$  F.get_item(x,indices)   
print(y)

运行结果

variable([[1 2 3] [1 2 3] [4 5 6])

以上就是对DeZero的get_item函数的介绍。接下来进行设置,使get_item函数也可以作为Variable的方法使用。代码如下所示。

Variable._*_item_ = F.get_item  
y = x[1]  
print(y)  
y = x[:,2]  
print(y)

运行结果

variable([456])  
variable([36])

上面用于设置的代码是Variable.getitem = get_item。用x[1]或x[:,2]等写法编写的代码在运行时会调用get_item函数,该切片操作的反向传播也能正确进行。这个特殊方法的设置在dezero/core.py的setup_variable函数中被调用(setup_variable函数是DeZero在初始化时调用的函数)。这样就能对Variable实例自由地进行切片操作了。下面开始本步骤的主要内容。