11.2_Add类的实现

11.2 Add类的实现

下面实现Add类的forward方法。需要注意的是参数和返回值应该是列表(或元组)。为了满足这一点,我们需要将代码编写成下面这样。

steps/step11.py

class Add(Function): def forward(self, xs): x0, x1 = xs y = x0 + x1 return (y,)

Add类的参数是包含两个变量的列表,所以通过 x0x0x1=xsx1 = xs 可以取出xs列表的元素。然后使用这些元素进行计算。在返回结果时,使用return(y,)(也可以写成“return y,”)来返回一个元组。这么处理后,我们就可以像下面这样使用Add类了。

steps/step11.py

xs = [Variable(np.array(2)), Variable(np.array(3))] # 初始化为列表  
f = Add()  
ys = f(xs) # ys是元组  
y = ys[0]  
print(y.data)

运行结果

5

如上面的代码所示,DeZero能正确地计算出 2+3=52 + 3 = 5 。输入变成列表后,DeZero可以处理多个变量;输出变成元组后,DeZero可以支持多个变量。现在的正向传播支持可变长的参数和返回值了,不过实现代码有些烦琐,因为使用Add类的人需要准备列表作为输入变量,并接收元组作为返回值。这种用法很别扭。在下一个步骤,我们将改进目前的实现,使代码更加自然。

步骤12

可变长参数(改进篇)

在上一个步骤,我们扩展了DeZero以支持可变长参数。不过,代码仍有改进空间。为了提高DeZero的易用性,这里对它进行两项改进。第1项改进针对的是使用Add类(或具体的函数类)的人,第2项改进针对的是实现Add类的人。先看第1项改进。