Skip to content

Dataset About PyTorch

Kelang edited this page Aug 1, 2020 · 5 revisions

这个示例展示如何将原始的numpy array数据在PyTorch下封装为Dataset类的数据集。

数据准备

直接通过keras.datasets加载mnist数据集,不能自动下载的话可以手动下载.npz文件并保存至相应目录下。保存的时候一行为一个图像信息,便于后续读取。由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件中存放的是每张图像的名字。 xxx_label.txt文件中存放的是类别标记。

def _save_split(split_dir, prefix, images, labels):
    """Write one data split to disk: two index files plus one PNG per sample.

    Creates ``<split_dir>/<prefix>_img.txt`` (one image path per line),
    ``<split_dir>/<prefix>_label.txt`` (one label per line, same order) and
    ``<split_dir>/img/<i>.png`` for every sample ``i``.

    Args:
        split_dir: Directory that receives the index files and the ``img``
            sub-directory. It is created if it does not exist.
        prefix: File-name prefix of the two index files, e.g. ``'baseset'``.
        images: Sequence of image arrays, written in iteration order.
        labels: Sequence of labels aligned with ``images``.
    """
    img_dir = os.path.join(split_dir, 'img')
    # imsave does not create missing directories, so make sure the target exists.
    os.makedirs(img_dir, exist_ok=True)

    img_list_path = os.path.join(split_dir, prefix + '_img.txt')
    label_list_path = os.path.join(split_dir, prefix + '_label.txt')
    # `with` guarantees both index files are closed even if saving an image fails.
    with open(img_list_path, 'w') as file_img, open(label_list_path, 'w') as file_label:
        for i, (pixels, label) in enumerate(zip(images, labels)):
            img_path = os.path.join(img_dir, str(i) + '.png')
            file_img.write(img_path + '\n')  # name
            file_label.write(str(label) + '\n')  # label
            matplotlib.image.imsave(img_path, pixels)


def LoadData(root_path, base_path, training_path, test_path):
    """Export the Keras MNIST arrays to PNG files with text index files.

    Three splits are written below ``root_path``: the base set (training and
    test samples concatenated, in that order), the training set and the test
    set. Each split directory receives ``<split>_img.txt``,
    ``<split>_label.txt`` and an ``img`` sub-directory holding ``<i>.png``.
    Images are written with ``matplotlib.image.imsave``.

    Args:
        root_path: Root output directory.
        base_path: Sub-directory (relative to ``root_path``) for the base set.
        training_path: Sub-directory for the training set.
        test_path: Sub-directory for the test set.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_baseset = np.concatenate((x_train, x_test))
    y_baseset = np.concatenate((y_train, y_test))

    # Paths are joined with os.path.join so the arguments work with or
    # without a trailing separator.
    _save_split(os.path.join(root_path, base_path), 'baseset', x_baseset, y_baseset)
    _save_split(os.path.join(root_path, training_path), 'trainingset', x_train, y_train)
    _save_split(os.path.join(root_path, test_path), 'testset', x_test, y_test)

展示Dataset用法

定义自己的Dataset类。PyTorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类。继承该类需要写初始化方法__init__、获取指定下标数据的方法__getitem__、获取数据个数的方法__len__。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
    """Dataset backed by a text file of image paths and a text file of labels.

    Both list files hold one entry per line and are matched by position:
    line ``i`` of the image list belongs to line ``i`` of the label list.

    Args:
        root_path: Prefix prepended (by plain string concatenation) to
            ``imgfile_path`` and ``labelfile_path``.
        imgfile_path: Path, relative to ``root_path``, of the image-list file.
            Each line is the path that is later handed to ``Image.open``.
        labelfile_path: Path, relative to ``root_path``, of the label file.
            Each line must parse as an integer.
        imgdata_path: Stored as ``self.imagedata_path``; not used by this
            class itself.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.

    Raises:
        ValueError: If a non-blank line of the label file is not an integer.
    """

    def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform=None):
        super().__init__()
        self.root_path = root_path
        self.transform = transform
        self.imagedata_path = imgdata_path
        self.image_name = self._read_lines(root_path + imgfile_path)
        label = [int(entry) for entry in self._read_lines(root_path + labelfile_path)]
        # Important: the labels must be converted to a LongTensor.
        self.label = torch.LongTensor(label)

    @staticmethod
    def _read_lines(path):
        """Return the stripped, non-blank lines of the text file at ``path``."""
        # `with` closes the file even when a caller-side parse error follows;
        # blank lines (e.g. a trailing empty line) are skipped instead of
        # producing an empty image path or an unparsable label.
        with open(path, 'r') as handle:
            stripped = [line.strip() for line in handle]
        return [entry for entry in stripped if entry]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened from the stored path, converted to RGB and, if a
        transform was given, passed through it. The label is the ``idx``-th
        element of the label tensor.
        """
        # convert() yields a separate image object, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(str(self.image_name[idx])) as source:
            image = source.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.label[idx]
        return image, label

    def __len__(self):
        """Return the number of image paths read from the image-list file."""
        return len(self.image_name)
import os
import matplotlib
import matplotlib.image as image
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import torch
import scipy.misc
import tensorflow as tf

# Directory layout used by the export step: one sub-directory per split
# below the root directory.
root_path = './mnist_np2dataset/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

# Uncomment on the first run to export MNIST to disk before building the dataset.
# LoadData(root_path, base_path, training_path, test_path)

# Locations of the training-set index files and images, relative to root_path.
training_imgfile = training_path + 'trainingset_img.txt'
training_labelfile = training_path + 'trainingset_label.txt'
training_imgdata = training_path + 'img/'

# Instantiate the dataset over the training split.
dataset = DataProcessingMnist(
    root_path=root_path,
    imgfile_path=training_imgfile,
    labelfile_path=training_labelfile,
    imgdata_path=training_imgdata,
)
name = dataset.image_name
print(name[0])

# Fetch the sample at a fixed index and show the types that come back.
im, label = dataset[0]
print("type im:", type(im))
print("type label:", type(label))
Clone this wiki locally