50.3_accuracy函数的实现

50.3 accuracy函数的实现

这里添加一个用于评估识别精度的函数accuracy。代码如下所示。

dezero/functions.py

def accuracy(y, t):
    y, t = as_variable(y), as_variable(t)
    pred = y.data.argmax(axis=1).reshape(t.shape)
    result = (pred == t.data)
    acc = result.mean()
    return Variable(as_array(acc))

accuracy函数用于计算参数y相对于t的“正确率”。其中,y是神经网络的预测结果,t是正确答案的数据。这两个参数是Variable实例或ndarray实例。

函数内部首先求出神经网络的预测结果pred。为此需要找出神经网络预测结果最大的索引,并进行reshape。然后将pred与正确答案的数据t进行比较,结果是True/False张量(ndarray)。计算张量为True的数据所占的比例(求平均值),得到的值就相当于正确率。

accuracy函数返回的是Variable实例,但函数内部的计算是针对ndarray实例进行的。因此,不能对accuracy函数求导。

上面的最后一行代码return Variable(as_array(acc))使用了as_array函数。这是因为acc(= result.mean())的数据类型是np.float64或np.float32。使用as_array函数对acc进行转换,会得到ndarray实例(as_array函数已在步骤9中实现)。

下面是使用accuracy函数计算正确率(识别精度)的示例。

import numpy as np   
importdezero-functionsasF   
y  $\equiv$  np.array([0.2,0.8,0],[0.1,0.9,0],[0.8,0.1,0.1])   
t  $=$  np.array([1,2,0])   
acc  $=$  F.accuracy(y,t)   
print(acc)

运行结果

variable(0.6666666666666666)

上面代码中的y是神经网络对3个样本数据的预测结果(这是一个3类分类),训练数据t是每个样本数据的正确答案的索引。通过accuracy函数计算出的识别精度为0.66...。