-
Notifications
You must be signed in to change notification settings - Fork 115
Dataset About PyTorch
Kelang edited this page Aug 1, 2020
·
5 revisions
这个示例展示如何将原始的 numpy array 数据在 PyTorch 下封装为 Dataset 类的数据集。
直接通过 keras.datasets 加载 mnist 数据集;不能自动下载的话,可以手动下载 mnist.npz 文件并保存至相应目录下(Keras 默认的缓存目录为 ~/.keras/datasets/)。保存的时候一行为一个图像信息,便于后续读取。
由于 mnist 数据集其实是灰度图,这里用 matplotlib 保存的图像是伪彩色图像。如果用 scipy.misc.imsave 的话,保存的则是灰度图像(注意:scipy.misc.imsave 在新版 SciPy 中已被移除,可改用 imageio.imwrite)。
xxx_img.txt 文件中每行存放一张图像的保存路径。
xxx_label.txt 文件中每行存放对应图像的类别标记。
def _save_split(root_path, split_path, prefix, images, labels):
    """Write one data split to disk as PNG files plus two index text files.

    Layout created under ``os.path.join(root_path, split_path)``:

    * ``img/<i>.png``          -- one image per sample, numbered from 0
    * ``<prefix>_img.txt``     -- one line per sample: the path of its PNG
    * ``<prefix>_label.txt``   -- one line per sample: ``str(label)``

    Args:
        root_path: Root output directory.
        split_path: Sub-directory of ``root_path`` for this split.
        prefix: File-name prefix of the two index files (e.g. ``'baseset'``).
        images: Iterable of image arrays accepted by ``matplotlib.image.imsave``.
        labels: Iterable of labels, parallel to ``images``.
    """
    # Build every path from the same base directory so the index files, the
    # images and the paths recorded in the index always agree, whether or not
    # the caller's arguments end with a path separator.
    split_dir = os.path.join(root_path, split_path)
    img_dir = os.path.join(split_dir, 'img')
    # imsave does not create missing directories.
    os.makedirs(img_dir, exist_ok=True)

    img_index = os.path.join(split_dir, prefix + '_img.txt')
    label_index = os.path.join(split_dir, prefix + '_label.txt')
    # Context managers close both index files even if saving an image fails.
    with open(img_index, 'w') as file_img, open(label_index, 'w') as file_label:
        for i, (img, lab) in enumerate(zip(images, labels)):
            img_path = os.path.join(img_dir, str(i) + '.png')
            file_img.write(img_path + '\n')  # image path
            file_label.write(str(lab) + '\n')  # label
            # The accompanying write-up notes that matplotlib stores these
            # single-channel arrays as pseudo-colour images.
            matplotlib.image.imsave(img_path, img)


def LoadData(root_path, base_path, training_path, test_path):
    """Export the Keras MNIST arrays as PNG images plus index/label text files.

    Three splits are written, each into its own sub-directory of
    ``root_path``: the concatenation of train and test (``baseset``), the
    training split (``trainingset``) and the test split (``testset``).
    Missing output directories are created.

    Args:
        root_path: Root output directory.
        base_path: Sub-directory for the combined train+test split.
        training_path: Sub-directory for the training split.
        test_path: Sub-directory for the test split.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_baseset = np.concatenate((x_train, x_test))
    y_baseset = np.concatenate((y_train, y_test))

    _save_split(root_path, base_path, 'baseset', x_baseset, y_baseset)
    _save_split(root_path, training_path, 'trainingset', x_train, y_train)
    _save_split(root_path, test_path, 'testset', x_test, y_test)
定义自己的 Dataset 类。PyTorch 训练数据时需要数据集为 Dataset 类,便于迭代等等,这里将加载保存之后的数据封装成 Dataset 类。继承该类需要写初始化方法 __init__、获取指定下标数据的方法 __getitem__,以及获取数据个数的方法 __len__。这里尤其需要注意的是,要把 label 转为 LongTensor 类型。
class DataProcessingMnist(Dataset):
    """Dataset backed by an image-path index file and a label index file.

    The image index holds one image path per line and the label index holds
    one integer label per line, in the same order. Blank lines in either
    file are ignored.

    Args:
        root_path: Directory that ``imgfile_path`` and ``labelfile_path``
            are relative to.
        imgfile_path: Image index file, relative to ``root_path``.
        labelfile_path: Label index file, relative to ``root_path``.
        imgdata_path: Image directory; stored as ``imagedata_path`` but not
            read by this class, because the index already holds the paths
            that are opened.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        ValueError: If the two index files list a different number of
            entries, or a label line is not an integer.
    """

    def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform=None):
        super().__init__()
        self.root_path = root_path
        self.transform = transform
        self.imagedata_path = imgdata_path

        # Context managers close the index files even if parsing fails.
        with open(os.path.join(root_path, imgfile_path), 'r') as img_file:
            self.image_name = [x.strip() for x in img_file if x.strip()]
        with open(os.path.join(root_path, labelfile_path), 'r') as label_file:
            label = [int(x.strip()) for x in label_file if x.strip()]

        # Fail at construction time rather than with an IndexError part-way
        # through iteration.
        if len(label) != len(self.image_name):
            raise ValueError(
                'image index lists %d entries but label index lists %d'
                % (len(self.image_name), len(label))
            )

        # Important: the labels must be converted to a LongTensor.
        self.label = torch.LongTensor(label)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened from the recorded path, converted to RGB and
        passed through ``transform`` when one was given.
        """
        # convert() returns a separate image object, so the file handle
        # opened by Image.open can be released immediately.
        with Image.open(str(self.image_name[idx])) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.label[idx]
        return image, label

    def __len__(self):
        """Return the number of samples listed in the image index."""
        return len(self.image_name)
import os
import matplotlib
import matplotlib.image as image
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import torch
import scipy.misc
import tensorflow as tf
# Output layout: one sub-directory per split underneath the root directory.
root_path = './mnist_np2dataset/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

# Export step, left disabled; uncomment to write the images and index files.
# LoadData(root_path, base_path, training_path, test_path)

# Training-split locations, all relative to root_path.
training_imgfile = training_path + 'trainingset_img.txt'
training_labelfile = training_path + 'trainingset_label.txt'
training_imgdata = training_path + 'img/'

# Instantiate the dataset over the training split.
dataset = DataProcessingMnist(
    root_path,
    training_imgfile,
    training_labelfile,
    training_imgdata,
)
name = dataset.image_name
print(name[0])

# Fetch the sample at a fixed index.
im, label = dataset[0]
print("type im:", type(im))
print("type label:", type(label))