python生成器读取数据集
生成器
在 Python 中,使用了 yield 的函数被称为生成器(generator)。跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。
调用一个生成器函数,返回的是一个迭代器对象。
更多信息可以参考菜鸟教程生成器与迭代器
数据集读取
在深度学习中训练模型的过程中读取图片数据,如果将图片数据全部读入内存是不现实的,所以有必要使用生成器来读取数据。
通过列表生成式,我们可以直接创建一个列表。但是,受到内存限制,列表容量肯定是有限的。而且,创建一个包含100万个元素的列表,不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了。
所以,如果列表元素可以按照某种算法推算出来,那我们可以在循环的过程中不断推算出后续的元素。这样就不必创建完整的list,从而节省大量的空间。
示例代码如下:
import numpy as np
import cv2
import os
from PIL import Image
def load_batch(dataset_name, batch_size):
data_dir = os.path.join("./")+dataset_name
imgs = os.listdir(data_dir)
num_of_files = len(imgs)
num_batch = num_of_files // batch_size
for i in range(num_batch):
images = []
data_batch = imgs[batch_size*i:batch_size*(i+1)]
for img_path in data_batch:
(filename, suffix) = img_path.split('.')
if suffix == 'jpg' or suffix == 'png':
img = Image.open(data_dir + "/" + imgs[i])
arr = np.asarray(img)
arr = arr.reshape((512, 512, 1))
images.append(arr)
yield np.array(images)
if __name__ == '__main__':
batch = load_batch('images', 8)
print(batch)
print(next(batch).shape)
print(next(batch).shape)