10.2_square函数反向传播的测试

10.2 square函数反向传播的测试

接下来添加针对square函数反向传播的测试。在刚刚实现的SquareTest类中添加以下代码。

steps/step10.py

class SquareTest(unittest.TestCase): def test_backward(self): x = Variable(np.array(3.0)) y = square(x) yopsis expected  $=$  np.array(6.0) self.assertEqual(x.grad,expected)

我们在代码中添加了一个名为 test_backward 的方法。在方法中通过 yopsis 来求导,然后检查导数的值是否等于预期值。另外,代码中设置的预期值(expected)6.0 是手动计算出来的。

现在再来测试一下上面的代码。输出结果如下所示

..   
Ran 2 tests in 0.001s   
OK

从结果来看,两项测试都通过了。之后可以采用和前面一样的做法来添加其他的测试用例(输入和预期值)。随着测试用例的增加,square函数的可靠性也会增加。另外,我们也可以在修改代码时进行测试,以此来反复验证square函数的状态。