58.2_已训练的权重数据
58.2 已训练的权重数据
VGG16是在大型数据集ImageNet上训练的,训练后的权重数据已开放下载。这里向刚才实现的VGG16类中添加用于加载已训练的权重数据的函数。

VGG16模型基于Creative Commons Attribution许可协议开放下载。另外,为了使DeZero能够读取模型原始的权重数据而对该模型施加了微小修改的权重文件可从GitHub上获得。
向VGG16类添加的代码如下所示
dezero/models.py
fromdezero importutils
class VGG16(Model): WEIGHTS_PATH $\equiv$ 'https://github.com/koki0702/dezero-models/' releases/download/v0.1/vgg16.npz' def__init__(self,pretrained $\equiv$ False): ... ifpretrained: weights_path $=$ utils.get_file(VGG16.WEIGHTS_PATH) self.loadweights(weights_path)上面的代码在VGG16类的初始化方法中添加了参数pretrained=False。如果参数为True,则从指定位置下载并读取权重文件(DeZero专用的转换过的权重文件)。加载权重文件是步骤53中增加的功能。

dezero/util.py中有一个get_file函数。该函数从指定的URL下载文件,然后返回下载文件(在PC上)的绝对路径。如果下载的文件已经在缓存目录中,它会返回该文件的绝对路径。DeZero的缓存目录是~/.dezero。
以上就是VGG16类的实现。VGG16类的代码在dezero/models.py中。下面是使用已训练的VGG16的实例代码。
import numpy as np
fromdezero.models import VGG16
model $=$ VGG16(pretrained $\equiv$ True)
$\mathbf{x} =$ np.random.randint(1,3,224,224).astype(np.float32)#虚拟数据
model.plot(x)为了实现可视化操作,上面的代码还创建了VGG16的计算图。结果如图58-2所示。

图58-2 VGG16的计算图