
Commit 46df7a7

release load SVHN (#422)
* release load SVHN
* fixed codacy
* fix liguo suggestion
1 parent a49f6c2 commit 46df7a7

File tree

4 files changed: 130 additions, 18 deletions


README.md

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ TensorLayer is a deep learning and reinforcement learning library on top of [Ten
 - Useful links: [Documentation](http://tensorlayer.readthedocs.io), [Examples](http://tensorlayer.readthedocs.io/en/latest/user/example.html), [中文文档](https://tensorlayercn.readthedocs.io), [中文书](http://www.broadview.com.cn/book/5059)

 # News
+* [16 Mar] Release experimental APIs for binary networks.
 * [18 Jan] [《深度学习:一起玩转TensorLayer》](http://www.broadview.com.cn/book/5059) (Deep Learning using TensorLayer)
 * [17 Dec] Release experimental APIs for distributed training (by [TensorPort](https://tensorport.com)). See [tiny example](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_mnist_distributed.py).
 * [17 Nov] Release data augmentation APIs for object detection, see [tl.prepro](http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html#object-detection).

docs/modules/files.rst

Lines changed: 5 additions & 0 deletions
@@ -9,6 +9,7 @@ API - Files
   load_mnist_dataset
   load_fashion_mnist_dataset
   load_cifar10_dataset
+  load_cropped_svhn
   load_ptb_dataset
   load_matt_mahoney_text8_dataset
   load_imdb_dataset
@@ -63,6 +64,10 @@ CIFAR-10
 ^^^^^^^^^^^^
 .. autofunction:: load_cifar10_dataset

+SVHN
+^^^^^^^
+.. autofunction:: load_cropped_svhn
+
 Penn TreeBank (PTB)
 ^^^^^^^^^^^^^^^^^^^^^
 .. autofunction:: load_ptb_dataset
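
Note: a minimal usage sketch for the newly documented loader (assuming TensorLayer is installed and imported as `tl`, per the docstring added in tensorlayer/files.py below; the 73257/26032 train/test counts are the standard Cropped SVHN split sizes and are not stated in this diff):

    import tensorlayer as tl

    # First call downloads the .mat files and caches them as .npz; later calls reuse the cache.
    X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(path='data', include_extra=False)

    print(X_train.shape, X_test.shape)   # expected (73257, 32, 32, 3) (26032, 32, 32, 3), values in [0, 1]
    print(y_train.min(), y_train.max())  # 0 ... 9; the dataset's label '10' (digit 0) is remapped to 0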

tensorlayer/files.py

Lines changed: 120 additions & 14 deletions
@@ -39,6 +39,7 @@
 import sys
 import tarfile
 import zipfile
+import time

 import numpy as np
 import tensorflow as tf
@@ -320,6 +321,106 @@ def unpickle(file):
     return X_train, y_train, X_test, y_test


+def load_cropped_svhn(path='data', include_extra=True):
+    """Load Cropped SVHN.
+
+    The Cropped Street View House Numbers (SVHN) Dataset contains 32x32x3 RGB images.
+    Digit '1' has label 1, '9' has label 9 and '0' has label 0 (the original dataset uses 10 to represent '0'), see `ufldl website <http://ufldl.stanford.edu/housenumbers/>`__.
+
+    Parameters
+    ----------
+    path : str
+        The path that the data is downloaded to.
+    include_extra : boolean
+        If True (default), add extra images to the training set.
+
+    Returns
+    -------
+    X_train, y_train, X_test, y_test : tuple
+        Return the split training/test sets respectively.
+
+    Examples
+    ---------
+    >>> X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
+    >>> tl.vis.save_images(X_train[0:100], [10, 10], 'svhn.png')
+
+    """
+    import scipy.io
+
+    start_time = time.time()
+
+    path = os.path.join(path, 'cropped_svhn')
+    logging.info("Load or Download Cropped SVHN > {} | include extra images: {}".format(path, include_extra))
+    url = "http://ufldl.stanford.edu/housenumbers/"
+
+    np_file = os.path.join(path, "train_32x32.npz")
+    if file_exists(np_file) is False:
+        filename = "train_32x32.mat"
+        filepath = maybe_download_and_extract(filename, path, url)
+        mat = scipy.io.loadmat(filepath)
+        X_train = mat['X'] / 255.0  # to [0, 1]
+        X_train = np.transpose(X_train, (3, 0, 1, 2))
+        y_train = np.squeeze(mat['y'], axis=1)
+        y_train[y_train == 10] = 0  # replace label 10 with 0
+        np.savez(np_file, X=X_train, y=y_train)
+        del_file(filepath)
+    else:
+        v = np.load(np_file)
+        X_train = v['X']
+        y_train = v['y']
+    logging.info(" n_train: {}".format(len(y_train)))
+
+    np_file = os.path.join(path, "test_32x32.npz")
+    if file_exists(np_file) is False:
+        filename = "test_32x32.mat"
+        filepath = maybe_download_and_extract(filename, path, url)
+        mat = scipy.io.loadmat(filepath)
+        X_test = mat['X'] / 255.0
+        X_test = np.transpose(X_test, (3, 0, 1, 2))
+        y_test = np.squeeze(mat['y'], axis=1)
+        y_test[y_test == 10] = 0
+        np.savez(np_file, X=X_test, y=y_test)
+        del_file(filepath)
+    else:
+        v = np.load(np_file)
+        X_test = v['X']
+        y_test = v['y']
+    logging.info(" n_test: {}".format(len(y_test)))
+
+    if include_extra:
+        logging.info(" getting extra 531131 images, please wait ...")
+        np_file = os.path.join(path, "extra_32x32.npz")
+        if file_exists(np_file) is False:
+            logging.info(" the first time to load extra images will take a long time to convert the file format ...")
+            filename = "extra_32x32.mat"
+            filepath = maybe_download_and_extract(filename, path, url)
+            mat = scipy.io.loadmat(filepath)
+            X_extra = mat['X'] / 255.0
+            X_extra = np.transpose(X_extra, (3, 0, 1, 2))
+            y_extra = np.squeeze(mat['y'], axis=1)
+            y_extra[y_extra == 10] = 0
+            np.savez(np_file, X=X_extra, y=y_extra)
+            del_file(filepath)
+        else:
+            v = np.load(np_file)
+            X_extra = v['X']
+            y_extra = v['y']
+        # print(X_train.shape, X_extra.shape)
+        logging.info(" adding n_extra {} to n_train {}".format(len(y_extra), len(y_train)))
+        t = time.time()
+        X_train = np.concatenate((X_train, X_extra), 0)
+        y_train = np.concatenate((y_train, y_extra), 0)
+        # X_train = np.append(X_train, X_extra, axis=0)
+        # y_train = np.append(y_train, y_extra, axis=0)
+        logging.info(" added n_extra {} to n_train {}, took {}s".format(len(y_extra), len(y_train), time.time() - t))
+    else:
+        logging.info(" no extra images are included")
+    logging.info(" image size: %s n_train: %d n_test: %d" % (str(X_train.shape[1:4]), len(y_train), len(y_test)))
+    logging.info(" took: {}s".format(int(time.time() - start_time)))
+    return X_train, y_train, X_test, y_test
+
+
 def load_ptb_dataset(path='data'):
     """Load Penn TreeBank (PTB) dataset.
@@ -656,19 +757,19 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
     url = 'http://press.liacs.nl/mirflickr/mirflickr25k/'

     # download dataset
-    if folder_exists(path + "/mirflickr") is False:
+    if folder_exists(os.path.join(path, "mirflickr")) is False:
         logging.info("[*] Flickr25k is nonexistent in {}".format(path))
         maybe_download_and_extract(filename, path, url, extract=True)
-        del_file(path + '/' + filename)
+        del_file(os.path.join(path, filename))

     # return images by the given tag.
     # 1. image path list
-    folder_imgs = path + "/mirflickr"
+    folder_imgs = os.path.join(path, "mirflickr")
     path_imgs = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
     path_imgs.sort(key=natural_keys)

     # 2. tag path list
-    folder_tags = path + "/mirflickr/meta/tags"
+    folder_tags = os.path.join(path, "mirflickr", "meta", "tags")
     path_tags = load_file_list(path=folder_tags, regx='\\.txt', printable=False)
     path_tags.sort(key=natural_keys)
@@ -679,7 +780,7 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
     logging.info("[Flickr25k] reading images with tag: {}".format(tag))
     images_list = []
     for idx, _v in enumerate(path_tags):
-        tags = read_file(folder_tags + '/' + path_tags[idx]).split('\n')
+        tags = read_file(os.path.join(folder_tags, path_tags[idx])).split('\n')
         # logging.info(idx+1, tags)
         if tag is None or tag in tags:
             images_list.append(path_imgs[idx])
@@ -722,6 +823,8 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
     >>> images = tl.files.load_flickr1M_dataset(tag='zebra')

     """
+    import shutil
+
     path = os.path.join(path, 'flickr1M')
     logging.info("[Flickr1M] using {}% of images = {}".format(size * 10, size * 100000))
     images_zip = [
@@ -734,20 +837,21 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
     for image_zip in images_zip[0:size]:
         image_folder = image_zip.split(".")[0]
         # logging.info(path+"/"+image_folder)
-        if folder_exists(path + "/" + image_folder) is False:
+        if folder_exists(os.path.join(path, image_folder)) is False:
             # logging.info(image_zip)
             logging.info("[Flickr1M] {} is missing in {}".format(image_folder, path))
             maybe_download_and_extract(image_zip, path, url, extract=True)
-            del_file(path + '/' + image_zip)
-            os.system("mv {} {}".format(path + '/images', path + '/' + image_folder))
+            del_file(os.path.join(path, image_zip))
+            # os.system("mv {} {}".format(os.path.join(path, 'images'), os.path.join(path, image_folder)))
+            shutil.move(os.path.join(path, 'images'), os.path.join(path, image_folder))
         else:
             logging.info("[Flickr1M] {} exists in {}".format(image_folder, path))

     # download tag
-    if folder_exists(path + "/tags") is False:
+    if folder_exists(os.path.join(path, "tags")) is False:
         logging.info("[Flickr1M] tag files is nonexistent in {}".format(path))
         maybe_download_and_extract(tag_zip, path, url, extract=True)
-        del_file(path + '/' + tag_zip)
+        del_file(os.path.join(path, tag_zip))
     else:
         logging.info("[Flickr1M] tags exists in {}".format(path))

@@ -761,17 +865,19 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
     for folder in images_folder_list[0:size * 10]:
         tmp = load_file_list(path=folder, regx='\\.jpg', printable=False)
         tmp.sort(key=lambda s: int(s.split('.')[-2]))  # ddd.jpg
-        images_list.extend([folder + '/' + x for x in tmp])
+        images_list.extend([os.path.join(folder, x) for x in tmp])

     # 2. tag path list
     tag_list = []
-    tag_folder_list = load_folder_list(path + "/tags")
-    tag_folder_list.sort(key=lambda s: int(s.split('/')[-1]))  # folder/images/ddd
+    tag_folder_list = load_folder_list(os.path.join(path, "tags"))
+
+    # tag_folder_list.sort(key=lambda s: int(s.split("/")[-1]))  # folder/images/ddd
+    tag_folder_list.sort(key=lambda s: int(os.path.basename(s)))

     for folder in tag_folder_list[0:size * 10]:
         tmp = load_file_list(path=folder, regx='\\.txt', printable=False)
         tmp.sort(key=lambda s: int(s.split('.')[-2]))  # ddd.txt
-        tmp = [folder + '/' + s for s in tmp]
+        tmp = [os.path.join(folder, s) for s in tmp]
         tag_list += tmp

     # 3. select images
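
Note: the Flickr hunks above are a portability cleanup: string concatenation with '/' becomes os.path.join, os.system("mv ...") becomes shutil.move, and s.split('/')[-1] becomes os.path.basename. A small standalone sketch of the difference (illustrative paths only, not from the commit):

    import os
    import shutil

    path = 'data'

    # String concatenation hard-codes the POSIX separator; os.path.join uses the current OS's.
    print(path + '/' + 'mirflickr')                         # always 'data/mirflickr'
    print(os.path.join(path, 'mirflickr', 'meta', 'tags'))  # 'data/mirflickr/meta/tags' (POSIX) or 'data\mirflickr\meta\tags' (Windows)

    # os.path.basename extracts the last component whatever the separator is,
    # which is why the sort key int(s.split('/')[-1]) was replaced.
    print(os.path.basename(os.path.join(path, 'tags', '42')))  # '42'

    # shutil.move raises on failure; os.system("mv src dst") returns an exit
    # code that is easy to ignore, and Windows has no 'mv' at all.
    # shutil.move(os.path.join(path, 'images'), os.path.join(path, 'images_0'))  # hypothetical paths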

tensorlayer/visualize.py

Lines changed: 4 additions & 4 deletions
@@ -25,7 +25,7 @@
 ]


-def read_image(image, path=''):
+def read_image(image, path='_.png'):
     """Read one image.

     Parameters
@@ -44,7 +44,7 @@ def read_image(image, path=''):
     return scipy.misc.imread(os.path.join(path, image))


-def read_images(img_list, path='', n_threads=10, printable=True):
+def read_images(img_list, path='_.png', n_threads=10, printable=True):
     """Returns all images in list by given path and name of each image file.

     Parameters
@@ -75,7 +75,7 @@ def read_images(img_list, path='', n_threads=10, printable=True):
     return imgs


-def save_image(image, image_path=''):
+def save_image(image, image_path='_temp.png'):
     """Save a image.

     Parameters
@@ -92,7 +92,7 @@ def save_image(image, image_path=''):
         scipy.misc.imsave(image_path, image[:, :, 0])


-def save_images(images, size, image_path=''):
+def save_images(images, size, image_path='_temp.png'):
     """Save multiple images into one single image.

     Parameters
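
Note: the visualize.py change only replaces the empty-string defaults with concrete fallback filenames, so the helpers no longer hand '' to scipy.misc.imsave/imread when the caller omits a path. A quick sketch of the effect (random data stands in for real images; assumes a build that includes this commit):

    import numpy as np
    import tensorlayer as tl

    images = np.random.rand(100, 32, 32, 3)  # hypothetical batch of 100 RGB images in [0, 1]

    # image_path now defaults to '_temp.png', so omitting it writes a 10x10
    # grid to ./_temp.png instead of failing on an empty filename.
    tl.vis.save_images(images, [10, 10])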
