26.2_从计算图转换为DOT语言
26.2 从计算图转换为DOT语言
下面我们开始实现前面探讨的内容。在实现get.dot_graph函数之前,首先要实现辅助函数.dot_var。函数名前面的_(下划线)表示我们打算只在本地使用这个函数,即只用于get.dot_graph函数。下面是.dot_var函数的代码和它的使用示例。
dezero/util.py
def __dot_var(v, verbose=False):
dot_var = '(' [label="{}", color=orange, style=filled]\n'
name = '' if v.name is None else v.name
if verbose and v.data is not None:
if v.name is not None:
name += ':'
name += str(v.shape) + ' ' + str(vdtype)
return dot_var.format(id(v), name)使用示例
$\mathbf{x} =$ Variable(np.random.randint(2,3))
x.name $=$ 'x'
print(_dot_var(x))
print(_dot_var(x,verbose=True))运行结果
4423761088 [label="x", color=orange, style=filled]
4423761088 [label="x: (2, 3) float64", color=orange, style=filled]前面的代码将一个Variable实例赋给_var函数,函数返回以DOT语言编写的表示实例信息的字符串。为了使指定的变量节点的ID唯一,这里使用了Python内置的id函数。使用id函数可以得到对象的ID。对象的ID是该对象特有的,因此,在用DOT语言时,我们可以将对象的ID用作节点的ID。
上面的代码还使用了format方法来操作字符串。format方法将字符串{}的部分替换成作为format参数传来的对象(字符串和整数等)的值。

.dot_var函数中有一个名为verbose的参数。当verbose为True时,_dot_var函数会将ndarray实例的形状和类型也作为标签输出。
下面实现一个能将DeZero的函数转换为DOT语言的工具函数。该函数名为.dot_func,代码如下所示。
dezero/util.py
def __dot_func(f):
dot_func = '\{ [label="{}", color=lightblue, style=filled, shape=box]\n'
txt = dot_func.format(id(f), f.__class__._name__)
dot_edge = '\{ -> \}\n'
for x in finputs:
txt += dot_edge.format(id(x), id(f))
for y in f.output:
txt += dot_edge.format(id(f), id(y))) # y是weakref
return txt使用示例
$\mathbf{x}\theta =$ Variable(np.array(1.0))
x1 $=$ Variable(np.array(1.0))
y $=$ x0 + x1
txt $=$ _dot_func(y creator)
print(txt)运行结果
4423742632 [label="Add", color=lightblue, style=filled, shape=box]
4403357456 -> 4423742632
4403358016 -> 4423742632
4423742632 -> 4423761088.dot_func函数用DOT语言记述了DeZero的函数。此外,它用DOT语言记述了函数与输入变量之间的连接,以及函数与输出变量之间的连接。回顾一下前面的内容:DeZero函数继承自Function类,它拥有实例变量inputs和outputs(图26-2)。

图26-2 Function类的inputs和outputs
准备工作完成了,下面实现get.dot_graph函数。我们可以参考Variable类的backward方法,编写出来的代码如下所示。
dezero/util.py
def get.dot_graph(output, verbose=True):
txt = ""
funcs = []
seen_set = set()
def add_func(f):
if f not in seen_set:
funcs.append(f)
#funcs.sort(key=lambda x: x_generation)
seen_set.add(f)add_func(output creator)
txt += _dot_var(output, verbose)
while funcs:
func = funcs.pop()
txt += _dot_func(func)
for x in func.inputs:
txt += _dot_var(x, verbose)
if xCreator is not None:
add_func(xCreator)
return 'digraph g{\n' + txt + '}上面代码的逻辑与Variable类的backward方法基本相同(阴影部分是与backward方法的实现不同的地方)。backward方法传播的是导数,但这里没有传播导数,而是向txt添加用DOT语言编写的字符串。
另外,在实际的反向传播中,节点的遍历顺序也很重要。为此,我们赋予了函数一个generation(辈分)整数值,并按照该值从大到小的顺序取出函数(详见步骤15和步骤16)。但在get_DOT_graph函数中,节点遍历的顺序并不重要,所以我们注释掉了按generation的值排序的代码。

这里需要关注的是“存在哪些节点”“哪个节点与哪个节点相连”。也就是说,节点的遍历顺序并不重要,所以我们不需要使用根据generation的值优先取出某些节点的机制。
计算图可视化的代码到此就全部完成了。下面添加一个能使计算图的可视化操作更为简单的函数。