MNIST 数据集图片化处理
Step1: 下载 MNIST 数据集
可以去官网下载，也可以通过我的百度云分享下载：
链接:https://pan.baidu.com/s/13MwGxwNkfvY85ISxaCAkrQ
提取码:pzn0
Step2: 提取图片（将 idx 格式的图片文件转换为 JPEG 图片，并将标签文件转换为 CSV 文件）
# -*- coding: utf-8 -*-
"""
Created on 2018-10-17
@author: Angus Cai
"""
import os
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def load_mnist_image(path, filename, type='train'):
    """Decode an MNIST idx3 image file and save every image as a JPEG file.

    For ``type == 'train'`` image ``i`` is written to ``./train/train_<i>.jpeg``;
    for ``type == 'test'`` it is written to ``./test/test_<i>.jpeg``. Paths are
    relative to the current working directory, and the output folder is created
    if missing. For any other ``type`` the header is parsed but nothing is saved.

    Args:
        path: Directory that contains the idx3 file.
        filename: Name of the idx3 image file inside ``path``.
        type: ``'train'`` or ``'test'``; selects the output folder and the
            file-name prefix. (Shadows the builtin ``type``; the name is kept
            for existing callers.)
    """
    full_name = os.path.join(path, filename)
    # Read everything up front and close the file right away; the previous
    # version never closed the handle.
    with open(full_name, 'rb') as fp:
        buf = fp.read()

    # Header: magic number, image count, rows, cols (big-endian uint32 each).
    header_fmt = '>IIII'
    _magic, num, rows, cols = struct.unpack_from(header_fmt, buf, 0)
    offset = struct.calcsize(header_fmt)

    if type not in ('train', 'test'):
        return  # nothing would be saved, so skip decoding the images

    out_dir = os.path.join('.', type)
    os.makedirs(out_dir, exist_ok=True)  # once, instead of re-checking per image

    # Image size comes from the header instead of a hard-coded 784 (= 28 * 28).
    image_size = rows * cols
    for image in range(num):
        # Pixels are stored as one unsigned byte each, row-major.
        pixels = np.frombuffer(buf, dtype=np.uint8, count=image_size, offset=offset)
        offset += image_size
        im = Image.fromarray(pixels.reshape(rows, cols))
        # NOTE: JPEG is lossy; the .jpeg names are kept because the
        # classification step looks up files as <prefix>_<index>.jpeg.
        im.save(os.path.join(out_dir, '%s_%s.jpeg' % (type, image)), 'jpeg')
def load_mnist_label(path, filename, type='train'):
    """Read an MNIST idx1 label file, optionally save it as CSV, and return it.

    When ``type`` is ``'train'`` or ``'test'`` the labels are also written,
    one integer per line, to ``./train_labels.csv`` or ``./test_labels.csv``
    (relative to the current working directory). Other values only return.

    Args:
        path: Directory that contains the idx1 file.
        filename: Name of the idx1 label file inside ``path``.
        type: ``'train'`` or ``'test'``; selects the CSV output name.
            (Shadows the builtin ``type``; kept for existing callers.)

    Returns:
        1-D float64 numpy array with one entry per label. The float dtype is
        kept for backward compatibility with the previous implementation.
    """
    full_name = os.path.join(path, filename)
    # Context manager closes the file; the previous version leaked the handle.
    with open(full_name, 'rb') as fp:
        buf = fp.read()

    # Header: magic number and label count (big-endian uint32 each).
    header_fmt = '>II'
    _magic, num = struct.unpack_from(header_fmt, buf, 0)
    offset = struct.calcsize(header_fmt)

    # One unsigned byte per label: decode them all at once instead of a
    # per-byte Python loop.
    labels = np.frombuffer(buf, dtype=np.uint8, count=num, offset=offset).astype(np.float64)

    if type in ('train', 'test'):
        np.savetxt('./%s_labels.csv' % type, labels, fmt='%i', delimiter=',')
    return labels
if __name__ == '__main__':
    # Directory holding the four MNIST idx files. os.path.join keeps the path
    # portable; the former '.\\MNIST_data' only resolved correctly on Windows.
    path = os.path.join('.', 'MNIST_data')
    # Same order as before: train images, train labels, test images, test labels.
    for kind, images_file, labels_file in (
            ('train', 'train-images.idx3-ubyte', 'train-labels.idx1-ubyte'),
            ('test', 't10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte'),
    ):
        load_mnist_image(path, images_file, kind)
        load_mnist_label(path, labels_file, kind)
Step3: 分类（根据 test_labels.csv 中的标签，将测试集图片移动到对应的 0~9 文件夹）
# -*- coding: utf-8 -*-
"""
Created on 2018-10-17
@author: Angus Cai
"""
import shutil
import numpy as np
import os
import csv
def classify_images(image_dir=os.path.join('.', 'test'),
                    label_file='test_labels.csv',
                    prefix='test',
                    dest_root='.'):
    """Move extracted MNIST images into one sub-folder per label.

    Row ``i`` of ``label_file`` is a one-column CSV row such as those in the
    ``test_labels.csv`` written by the extraction step. It holds the label of
    ``<image_dir>/<prefix>_<i>.jpeg``, and that image is moved to
    ``<dest_root>/<label>/<prefix>_<i>.jpeg``. Destination folders are created
    as needed, and each processed row index is printed as progress.

    The defaults reproduce the original script: sort ``./test`` using
    ``test_labels.csv`` into ``./0`` ... ``./9``.

    Args:
        image_dir: Folder containing the extracted JPEG images.
        label_file: CSV file with one integer label per row.
        prefix: File-name prefix of the images (e.g. ``'test'`` or ``'train'``).
        dest_root: Folder under which the per-label folders are created.
    """
    # Context manager closes the CSV file (the original left it open).
    with open(label_file, 'r', newline='') as csv_file:
        for index, row in enumerate(csv.reader(csv_file)):
            label = int(''.join(row))  # parsed once per row, not once per digit
            # The destination folders used to be required to exist beforehand,
            # otherwise shutil.move failed; create them on demand instead.
            dest_dir = os.path.join(dest_root, str(label))
            os.makedirs(dest_dir, exist_ok=True)
            file_name = '%s_%d.jpeg' % (prefix, index)
            shutil.move(os.path.join(image_dir, file_name),
                        os.path.join(dest_dir, file_name))
            print(index)


if __name__ == '__main__':
    classify_images()
分类好的 MNIST 图片数据集下载
链接:https://pan.baidu.com/s/1JGKEGudXBaFEK8eBuZ-EPw
提取码:4wnb