11.1_修改Function类

11.1 修改Function类

现在修改Function类以支持多个输入和输出。为此,我们考虑将变量放入一个列表(或元组)中进行处理。换言之,修改后的Function类像之前一样接收“一个参数”并返回“一个值”。不同的是,参数和返回值被修改为列表,列表中包含需要的变量。

Python的列表和元组能保存多条数据。列表用[]将数据括起来,如[1,2,3];元组用()将数据括起来,如(1,2,3)。列表和元组的主要区别是,元组一旦创建,其元素就不能改变了。例如,对于元组 x=(1,2,3)\mathsf{x} = (1,2,3) 不能用 ×[θ]=4\times [\theta ] = 4 等方法来改变元素,但如果是列表,就可以改变。

首先回顾一下前面已经实现的Function类,代码如下所示。

steps/step10.py

class Function: def__call__(self,input): x  $=$  input.data y  $=$  self.forward(x) output  $=$  Variable(as_array(y)) output.set creator(self) self.input  $\equiv$  input self.output  $\equiv$  output return output def forward(self,x): raise NotImplementedError() def backward(self,gy): raise NotImplementedError()

Function的__call__方法将实际数据从Variable这个“箱子”里取出,然后通过forward方法进行具体的计算。然后,它把结果封装在Variable中,并让结果记住Function是它的“创造者”。在此基础上,我们将__call__方法的参数和返回值修改为列表。

steps/step11.py

class Function: def__call__(self,inputs): xs  $=$  [x.data for x in inputs] ys  $=$  self.forward(xs) outputs  $=$  [Variable(as_array(y))for y in ys] for output in outputs: output.set creator(self) self.Inputs  $\equiv$  inputs self.outputs  $\equiv$  outputs return outputs   
def forward(self,xs): raise NotImplementedError()   
def backward(self,gys): raise NotImplementedError()

上面的代码将参数和返回值改为列表。除了将变量放入列表进行处理这一点,其余处理的逻辑与之前的一样。另外,这里使用了列表生成式写法创建了新的列表。

列表生成式写法具体来说就是xs = [x.data for x in inputs]这样的写法,此处示例表示对于inputs列表中的各元素x取出相应的数据(x.data),并创建一个由这些元素组成的新列表。

以上就是新的Function类的代码。接下来,我们使用这个新的Function类来实现一个具体的函数。首先实现执行加法运算的Add类。