53.2_Layer类参数的扁平化

53.2 Layer类参数的扁平化

首先回顾一下Layer类的层次结构。层次结构是一个嵌套的结构,Layer中还有别的Layer。具体示例如下所示。

layer  $=$  Layer()   
l1  $=$  Layer()   
l1.p1  $=$  Parameter(np.array(1))   
layer.l1  $=$  l1   
layer.p2  $=$  Parameter(np.array(2))   
layer.p3  $=$  Parameter(np.array(3))

上面的 layer 中包含另一个层 l1。图 53-1 是这个层次结构的可视化图形。


图53-1 Layer类的层次结构

现在考虑从图53-1所示的层次结构中将Parameter作为一个扁平的、非嵌套的字典取出。为此要在Layer类中添加一个名为 Flatten_parameters的方法。我们先来看看这个方法的用法。

params_dict = {}  
layer._ flatten.params.params_dict)  
print.params_dict)

运行结果

{'p2': variable(2), 'l1/p1': variable(1), 'p3': variable(3)}

上面的代码准备了字典params_dict = {}, 并将其作为参数传给函数, 即 layer._ flatten.params(param_dict)。然后, layer中包含的参数被“扁平地”取出。实际上, l1层中的p1参数是用l1/p1这个键存储的。下面是_flatten.params 方法的代码。

dezero/layers.py

class Layer: def flatten.params(self, params_dict, parent_key=""): for name in self._params: obj  $=$  self._dict_[name] key  $=$  parent_key  $+ \prime / +$  name if parent_key else name if isinstance(obj, Layer): obj._ flatten.params.params_dict, key) else: params_dict[key]  $=$  obj

这个方法接收的参数是字典params_dict和文本parent_key。顺带提一下,Layer类的实例变量.params保存的是Parameter的实例变量名称或者Layer的实例变量名称。因此我们需要通过obj = self.dict[name]将实际的对象取出。之后,如果取出的obj是Layer,则调用该obj的_flatten.params方法。通过(递归)调用,我们就能以扁平的结构取出Parameter。