python生成器读取数据集

生成器

在 Python 中,使用了 yield 的函数被称为生成器(generator)。跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。

调用一个生成器函数,返回的是一个迭代器对象。

更多信息可以参考菜鸟教程生成器与迭代器

数据集读取

在深度学习中训练模型的过程中读取图片数据,如果将图片数据全部读入内存是不现实的,所以有必要使用生成器来读取数据。

通过列表生成式,我们可以直接创建一个列表。但是,受到内存限制,列表容量肯定是有限的。而且,创建一个包含100万个元素的列表,不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了。

所以,如果列表元素可以按照某种算法推算出来,那我们可以在循环的过程中不断推算出后续的元素。这样就不必创建完整的list,从而节省大量的空间。

示例代码如下:

import numpy as np
import cv2
import os
from PIL import Image

def load_batch(dataset_name, batch_size):
    data_dir = os.path.join("./")+dataset_name
    imgs = os.listdir(data_dir)
    num_of_files = len(imgs)
    num_batch = num_of_files // batch_size
    for i in range(num_batch):
        images = []
        data_batch = imgs[batch_size*i:batch_size*(i+1)]
        for img_path in data_batch:
            (filename, suffix) = img_path.split('.')
            if suffix == 'jpg' or suffix == 'png':
                img = Image.open(data_dir + "/" + imgs[i])
                arr = np.asarray(img)
                arr = arr.reshape((512, 512, 1))
                images.append(arr)
        yield np.array(images)

if __name__ == '__main__':
    batch = load_batch('images', 8)
    print(batch)
    print(next(batch).shape)
    print(next(batch).shape)