12.2_第2项改进:使函数更容易实现
12.2 第2项改进:使函数更容易实现
第2项改进针对的是实现Add类的人。目前,要实现Add类,需要编写图12-2左侧所示的代码。
class Add(Function): def forward(self, xs): x0, x1 = xs y = x0 + x1 return (y,)
图12-2 现在的代码(左)和改进后的代码(右)
class Add(Function): def forward(self, x0, x1): y = x0 + x1 return y如图12-2左侧的代码所示,具体的处理编写在Add类的forward方法中。在这个实现中,参数以列表的形式传递,返回值以元组的形式返回。当然,图12-2右侧的代码更为理想。在该代码下,forward方法的参数直接接收变量,直接返回结果变量。第2项改进就是实现这样的代码。
我们按如下方式修改Function类,完成第2项改进。
steps/step12.py
class Function: def__call__(self,\*inputs): xs $=$ [x.data for x in inputs] ys $=$ self.forward(\*xs)# $①$ 使用星号解包 if not isinstance(ys,tuple): # $②$ 对非元组情况的额外处理 ys $=$ (ys,) outputs $=$ [Variable(as_array(y))for y in ys] for output in outputs: output.set creator(self) self.Inputs $\equiv$ inputs self.outputs $\equiv$ outputs return outputs if len(outputs) $>1$ else outputs[0]首先是①处的self.forward(*xs)。这里在调用函数时在参数前加上了星号,由此可解包列表。解包是指将列表中的元素展开并将这些元素作为参数传递的过程。例如在xs = [x0, x1]的情况下,调用self.forward(*xs)就相当于调用self.forward(x0, x1)。
接着是②处。如果ys不是元组,就把它修改为元组。这样在forward方法的实现中,如果返回的元素只有1个,就可以直接返回这个元素。基于这些修改,我们可以按如下方式实现Add类。
steps/step12.py
class Add(Function): def forward(self, x0, x1): y = x0 + x1 return y上面的代码定义了def forward(self, x0, x1):。此外,结果可以写成return y这种只返回一个元素的形式,这样对实现Add类的人来说,DeZero就更好写了。到这里,我们就完成了第2项改进。