50.1_什么是迭代器

50.1 什么是迭代器

顾名思义,迭代器可以重复地(迭代)提取元素。Python的迭代器提供了从列表和元组等具有多个元素的数据类型中依次取出数据的功能,具体示例如下。

>>> t = [1, 2, 3]
>>> x = iter(t)
>>> next(x)
1
>>> next(x)
2
>>> next(x)
3
>>> next(x)
Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
StopIteration

我们可以使用iter函数将列表转换成迭代器。上面的代码基于列表t创建了迭代器x。next函数用于从迭代器中按顺序取出数据。在上面的例子中,每次执行next函数,函数都会依次取出列表中的元素。在第4次执行next时,由于下一个元素不存在,所以出现了StopIteration异常。

使用for语句从列表中取出元素时,其内部(用户看不见的地方)使用了迭代器功能。假设有 t=[1,2,3]t = [1, 2, 3] ,那么在运行for x in t: x时,列表t会在语句内部转换为迭代器。

我们也可以创建一个Python迭代器。以下是自制的迭代器代码。

class MyIterator: def __init__(self, max_cnt): self.max_cnt = max_cnt self_cnt = 0 def __iter__(self): return self def __next__(self): if self_cnt == self.max_cnt: raise StopIteration() self_cnt += 1 return self_cnt

上面是MyIterator类的代码。为了使这个类作为Python的迭代器使用,我们实现了特殊方法__iter__,它会返回自身(self)。然后实现了特殊方法__next__,它将返回下一个元素。如果没有要返回的元素,则执行raise StopIteration()。这样,MyIterator的实例就可以作为迭代器使用了。下面是它的使用示例。

obj  $=$  MyIterator(5)   
for  $\mathbf{x}$  in obj: print(x)

运行结果

1   
2   
3   
4   
5

上面的代码使用for ×\times in obj:语句取出了元素。下面利用迭代器的机制,实现用于取出DeZero的小批量数据的DataLoader类。这个类从给定的数据集中按顺序从头开始取出数据,并根据需要重排数据集。DataLoader的代码如下所示。

dezero/datalogcers.py

import math   
import random   
import numpy as np   
class DataLoader: def__init__(self,dataset,batch_size,shuffle=True): self(dataset  $=$  dataset self.batch_size  $=$  batch_size selfshuffle  $=$  shuffle self.data_size  $=$  len(dataset) self.max_iter  $=$  math.ceil(self.data_size/batch_size) self.reset() def reset(self): self_iteration  $= 0$  if selfshuffle: self.index  $=$  np.random.permutation(len(self(dataset)) else: self.index  $=$  np.arange(len(self(dataset)) def__iter__(self): return self
def __next__(self):
    if self_iteration >= self.max_iter:
        self.reset()
        raise StopIteration
    i, batch_size = self_iteration , self.batch_size
    batch_index = self.index[i * batch_size:(i + 1) * batch_size]
    batch = [self(dataset[i] for i in batch_index]
    x = np.array([example[0] for example in batch])
    t = np.array([example[1] for example in batch])
    self_iteration += 1
    return x, t
def next(self):
    return self._next_(   )

这个类在初始化时接收以下参数。

  • dataset: 具有 Dataset 接口的实例

  • batch_size: 小批量数据的大小

  • shuffle: 在每轮训练时是否对数据集进行重排

初始化方法在将参数设置为实例变量的参数之后调用了 reset 方法。reset 方法将 iteration 实例变量的次数设置为 0,并根据需要重排数据的索引。

__next__方法取出小批量数据,并将其转换为ndarray实例。它的代码与前面编写的代码相同,这里不再赘述。

dezero/dataloaders.py 中的 DataLoader 类的代码还包括向 GPU 传输数据的机制。上面的代码省略了支持 GPU 的代码。步骤 52 将介绍相关内容。

最后,在dezero/init.py中添加导入语句fromdezero.dataloggers

import DataLoader。这样,用户就可以通过fromdezero import DataLoader导入DataLoader(不必编写fromdezero@databoaders import DataLoader)。