11.1_修改Function类
11.1 修改Function类
现在修改Function类以支持多个输入和输出。为此,我们考虑将变量放入一个列表(或元组)中进行处理。换言之,修改后的Function类像之前一样接收“一个参数”并返回“一个值”。不同的是,参数和返回值被修改为列表,列表中包含需要的变量。

Python的列表和元组能保存多条数据。列表用[]将数据括起来,如[1,2,3];元组用()将数据括起来,如(1,2,3)。列表和元组的主要区别是,元组一旦创建,其元素就不能改变了。例如,对于元组 不能用 等方法来改变元素,但如果是列表,就可以改变。
首先回顾一下前面已经实现的Function类,代码如下所示。
steps/step10.py
class Function: def__call__(self,input): x $=$ input.data y $=$ self.forward(x) output $=$ Variable(as_array(y)) output.set creator(self) self.input $\equiv$ input self.output $\equiv$ output return output def forward(self,x): raise NotImplementedError() def backward(self,gy): raise NotImplementedError()Function的__call__方法将实际数据从Variable这个“箱子”里取出,然后通过forward方法进行具体的计算。然后,它把结果封装在Variable中,并让结果记住Function是它的“创造者”。在此基础上,我们将__call__方法的参数和返回值修改为列表。
steps/step11.py
class Function: def__call__(self,inputs): xs $=$ [x.data for x in inputs] ys $=$ self.forward(xs) outputs $=$ [Variable(as_array(y))for y in ys] for output in outputs: output.set creator(self) self.Inputs $\equiv$ inputs self.outputs $\equiv$ outputs return outputs
def forward(self,xs): raise NotImplementedError()
def backward(self,gys): raise NotImplementedError()上面的代码将参数和返回值改为列表。除了将变量放入列表进行处理这一点,其余处理的逻辑与之前的一样。另外,这里使用了列表生成式写法创建了新的列表。

列表生成式写法具体来说就是xs = [x.data for x in inputs]这样的写法,此处示例表示对于inputs列表中的各元素x取出相应的数据(x.data),并创建一个由这些元素组成的新列表。
以上就是新的Function类的代码。接下来,我们使用这个新的Function类来实现一个具体的函数。首先实现执行加法运算的Add类。