Mnist数据集图片化处理

Step1: 下载mnist数据集

可以去官网下载,也可以通过我的百度云分享下载。

链接:https://pan.baidu.com/s/13MwGxwNkfvY85ISxaCAkrQ
提取码:pzn0

Step2:提取图片

# -*- coding: utf-8 -*-
"""
Created on 2018-10-17
@author: Angus Cai
"""

import os
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def load_mnist_image(path, filename, type = 'train'):
    """Decode an MNIST IDX3 image file and save every image as a JPEG.

    The file starts with four big-endian uint32 values (magic number,
    image count, rows, cols) followed by ``rows * cols`` unsigned bytes
    per image.

    Args:
        path: Directory that contains the IDX file.
        filename: Name of the IDX image file inside ``path``.
        type: ``'train'`` writes ``./train/train_<i>.jpeg``; ``'test'``
            writes ``./test/test_<i>.jpeg``. Any other value saves nothing.

    Returns:
        None. The images are written to disk as a side effect.
    """
    full_name = os.path.join(path, filename)
    # Read the whole file at once; the context manager guarantees the
    # handle is closed even if reading fails.
    with open(full_name, 'rb') as fp:
        buf = fp.read()

    header_fmt = '>IIII'
    magic, num, rows, cols = struct.unpack_from(header_fmt, buf, 0)
    index = struct.calcsize(header_fmt)

    # Only the two known dataset splits produce output.
    if type not in ('train', 'test'):
        return

    # Create the output folder once instead of checking on every image.
    out_dir = './' + type
    os.makedirs(out_dir, exist_ok=True)

    # Use the dimensions stored in the header rather than assuming 28x28.
    pixels = rows * cols
    for image in range(num):
        # View the raw bytes in place; avoids building a Python tuple of
        # one int per pixel for each image.
        im = np.frombuffer(buf, dtype=np.uint8, count=pixels, offset=index)
        index += pixels
        im = Image.fromarray(im.reshape(rows, cols))
        im.save('%s/%s_%s.jpeg' % (out_dir, type, image), 'jpeg')

def load_mnist_label(path, filename, type = 'train'):
    """Decode an MNIST IDX1 label file and optionally dump it to CSV.

    The file starts with two big-endian uint32 values (magic number,
    label count) followed by one unsigned byte per label.

    Args:
        path: Directory that contains the IDX file.
        filename: Name of the IDX label file inside ``path``.
        type: ``'train'`` writes ``./train_labels.csv``; ``'test'`` writes
            ``./test_labels.csv``. Any other value writes no CSV.

    Returns:
        A 1-D float64 NumPy array holding one label per entry.
    """
    full_name = os.path.join(path, filename)
    # The context manager closes the handle even if reading fails.
    with open(full_name, 'rb') as fp:
        buf = fp.read()

    header_fmt = '>II'
    magic, num = struct.unpack_from(header_fmt, buf, 0)
    offset = struct.calcsize(header_fmt)

    # Decode all label bytes in one call. astype() returns a writable
    # float64 copy, matching the dtype of the np.zeros() array used before.
    Labels = np.frombuffer(
        buf, dtype=np.uint8, count=num, offset=offset
    ).astype(np.float64)

    if type in ('train', 'test'):
        # One integer label per line.
        np.savetxt('./%s_labels.csv' % type, Labels, fmt='%i', delimiter=',')

    return Labels

if __name__ == '__main__':
    # Directory that holds the four raw MNIST IDX files. Built with
    # os.path.join so the script also works where '\\' is not a path
    # separator (the previous '.\\MNIST_data' literal was Windows-only).
    path = os.path.join('.', 'MNIST_data')

    # Training split: images go to ./train/, labels to ./train_labels.csv
    train_images = 'train-images.idx3-ubyte'
    load_mnist_image(path, train_images, 'train')
    train_labels = 'train-labels.idx1-ubyte'
    load_mnist_label(path, train_labels, 'train')

    # Test split: images go to ./test/, labels to ./test_labels.csv
    test_images = 't10k-images.idx3-ubyte'
    load_mnist_image(path, test_images, 'test')
    test_labels = 't10k-labels.idx1-ubyte'
    load_mnist_label(path, test_labels, 'test')

Step3:分类

# -*- coding: utf-8 -*-
"""
Created on 2018-10-17
@author: Angus Cai
"""
import shutil
import numpy as np
import os
import csv

# Folder that holds the extracted test images (test_<index>.jpeg).
image_path = os.path.join(".", "test")

# One destination folder per digit class: ./0 ... ./9
dest_dirs = [os.path.join(".", str(digit)) for digit in range(10)]

# Folder that holds test_labels.csv.
label_path = "./"

# shutil.move does not create missing parent folders, so make sure every
# class folder exists before any file is moved.
for dest in dest_dirs:
    os.makedirs(dest, exist_ok=True)

# Row i of the CSV is the label of test_<i>.jpeg. The context manager
# closes the file once all rows are processed.
with open(os.path.join(label_path, "test_labels.csv"), "r", newline="") as csvFile:
    for index, label in enumerate(csv.reader(csvFile)):
        if not label:
            # Skip blank lines (e.g. a trailing newline) instead of
            # crashing on int("").
            continue
        digit = int("".join(label))
        # Labels outside 0-9 are left where they are.
        if 0 <= digit <= 9:
            file_name = "test_" + str(index) + ".jpeg"
            shutil.move(
                os.path.join(image_path, file_name),
                os.path.join(dest_dirs[digit], file_name),
            )
        print(index)

分类好的Mnist图片数据集下载

链接:https://pan.baidu.com/s/1JGKEGudXBaFEK8eBuZ-EPw
提取码:4wnb