@@ -39,6 +39,7 @@
import sys
import tarfile
import zipfile
+ import time

import numpy as np
import tensorflow as tf
@@ -320,6 +321,106 @@ def unpickle(file):
    return X_train, y_train, X_test, y_test


+ def load_cropped_svhn(path='data', include_extra=True):
+     """Load Cropped SVHN.
+
+     The Cropped Street View House Numbers (SVHN) dataset contains 32x32x3 RGB images.
+     Digit '1' has label 1 and '9' has label 9, while '0' has label 0 (the original dataset uses 10 to represent '0'); see the `ufldl website <http://ufldl.stanford.edu/housenumbers/>`__.
+
+     Parameters
+     ----------
+     path : str
+         The path that the data is downloaded to.
+     include_extra : boolean
+         If True (default), add the extra images to the training set.
+
+     Returns
+     -------
+     X_train, y_train, X_test, y_test : tuple
+         Return the split training/test sets respectively.
+
+     Examples
+     ---------
+     >>> X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
+     >>> tl.vis.save_images(X_train[0:100], [10, 10], 'svhn.png')
+
+     """
+
+     import scipy.io
+
+     start_time = time.time()
+
+     path = os.path.join(path, 'cropped_svhn')
+     logging.info("Load or Download Cropped SVHN > {} | include extra images: {}".format(path, include_extra))
+     url = "http://ufldl.stanford.edu/housenumbers/"
+
+     np_file = os.path.join(path, "train_32x32.npz")
+     if file_exists(np_file) is False:
+         filename = "train_32x32.mat"
+         filepath = maybe_download_and_extract(filename, path, url)
+         mat = scipy.io.loadmat(filepath)
+         X_train = mat['X'] / 255.0  # rescale to [0, 1]
+         X_train = np.transpose(X_train, (3, 0, 1, 2))  # (H, W, C, N) to (N, H, W, C)
+         y_train = np.squeeze(mat['y'], axis=1)
+         y_train[y_train == 10] = 0  # replace label 10 with 0
+         np.savez(np_file, X=X_train, y=y_train)  # cache as .npz for faster reloads
+         del_file(filepath)
+     else:
+         v = np.load(np_file)
+         X_train = v['X']
+         y_train = v['y']
+     logging.info("    n_train: {}".format(len(y_train)))
+
+     np_file = os.path.join(path, "test_32x32.npz")
+     if file_exists(np_file) is False:
+         filename = "test_32x32.mat"
+         filepath = maybe_download_and_extract(filename, path, url)
+         mat = scipy.io.loadmat(filepath)
+         X_test = mat['X'] / 255.0
+         X_test = np.transpose(X_test, (3, 0, 1, 2))
+         y_test = np.squeeze(mat['y'], axis=1)
+         y_test[y_test == 10] = 0
+         np.savez(np_file, X=X_test, y=y_test)
+         del_file(filepath)
+     else:
+         v = np.load(np_file)
+         X_test = v['X']
+         y_test = v['y']
+     logging.info("    n_test: {}".format(len(y_test)))
+
+     if include_extra:
+         logging.info("    getting extra 531131 images, please wait ...")
+         np_file = os.path.join(path, "extra_32x32.npz")
+         if file_exists(np_file) is False:
+             logging.info("    the first load of the extra images takes a long time to convert the file format ...")
+             filename = "extra_32x32.mat"
+             filepath = maybe_download_and_extract(filename, path, url)
+             mat = scipy.io.loadmat(filepath)
+             X_extra = mat['X'] / 255.0
+             X_extra = np.transpose(X_extra, (3, 0, 1, 2))
+             y_extra = np.squeeze(mat['y'], axis=1)
+             y_extra[y_extra == 10] = 0
+             np.savez(np_file, X=X_extra, y=y_extra)
+             del_file(filepath)
+         else:
+             v = np.load(np_file)
+             X_extra = v['X']
+             y_extra = v['y']
+         # print(X_train.shape, X_extra.shape)
+         logging.info("    adding n_extra {} to n_train {}".format(len(y_extra), len(y_train)))
+         t = time.time()
+         X_train = np.concatenate((X_train, X_extra), 0)
+         y_train = np.concatenate((y_train, y_extra), 0)
+         # X_train = np.append(X_train, X_extra, axis=0)
+         # y_train = np.append(y_train, y_extra, axis=0)
+         logging.info("    added n_extra {} to n_train {}, took {}s".format(len(y_extra), len(y_train), time.time() - t))
+     else:
+         logging.info("    no extra images are included")
+     logging.info("    image size: %s  n_train: %d  n_test: %d" % (str(X_train.shape[1:4]), len(y_train), len(y_test)))
+     logging.info("    took: {}s".format(int(time.time() - start_time)))
+     return X_train, y_train, X_test, y_test
+
+
def load_ptb_dataset(path='data'):
    """Load Penn TreeBank (PTB) dataset.

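The loader above caches each converted .mat split as an .npz so later calls skip the conversion. A quick way to sanity-check the new function end to end; this is a hypothetical sketch, assuming TensorLayer is installed and importable as `tl`:

import numpy as np
import tensorlayer as tl

# Skip the 531131 extra images for a fast first run.
X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)

assert X_train.shape[1:] == (32, 32, 3)  # NHWC layout after the transpose
assert 0.0 <= X_train.min() and X_train.max() <= 1.0  # rescaled from [0, 255]
assert set(np.unique(y_train)) <= set(range(10))  # label 10 remapped to 0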
@@ -656,19 +757,19 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
    url = 'http://press.liacs.nl/mirflickr/mirflickr25k/'

    # download dataset
-     if folder_exists(path + "/mirflickr") is False:
+     if folder_exists(os.path.join(path, "mirflickr")) is False:
        logging.info("[*] Flickr25k is nonexistent in {}".format(path))
        maybe_download_and_extract(filename, path, url, extract=True)
-         del_file(path + '/' + filename)
+         del_file(os.path.join(path, filename))

    # return images by the given tag.
    # 1. image path list
-     folder_imgs = path + "/mirflickr"
+     folder_imgs = os.path.join(path, "mirflickr")
    path_imgs = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
    path_imgs.sort(key=natural_keys)

    # 2. tag path list
-     folder_tags = path + "/mirflickr/meta/tags"
+     folder_tags = os.path.join(path, "mirflickr", "meta", "tags")
    path_tags = load_file_list(path=folder_tags, regx='\\.txt', printable=False)
    path_tags.sort(key=natural_keys)

@@ -679,7 +780,7 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
    logging.info("[Flickr25k] reading images with tag: {}".format(tag))
    images_list = []
    for idx, _v in enumerate(path_tags):
-         tags = read_file(folder_tags + '/' + path_tags[idx]).split('\n')
+         tags = read_file(os.path.join(folder_tags, path_tags[idx])).split('\n')
        # logging.info(idx+1, tags)
        if tag is None or tag in tags:
            images_list.append(path_imgs[idx])
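The recurring edit in these hunks swaps hand-built '/' paths for os.path.join, which uses the separator of the host OS and avoids doubled or missing slashes. A minimal illustration (the names are hypothetical):

import os

path, folder = "data", "mirflickr"
print(path + "/" + folder)         # always 'data/mirflickr', even on Windows
print(os.path.join(path, folder))  # 'data/mirflickr' on POSIX, 'data\\mirflickr' on Windows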
@@ -722,6 +823,8 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
    >>> images = tl.files.load_flickr1M_dataset(tag='zebra')

    """
+     import shutil
+
    path = os.path.join(path, 'flickr1M')
    logging.info("[Flickr1M] using {}% of images = {}".format(size * 10, size * 100000))
    images_zip = [
@@ -734,20 +837,21 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
    for image_zip in images_zip[0:size]:
        image_folder = image_zip.split(".")[0]
        # logging.info(path+"/"+image_folder)
-         if folder_exists(path + "/" + image_folder) is False:
+         if folder_exists(os.path.join(path, image_folder)) is False:
            # logging.info(image_zip)
            logging.info("[Flickr1M] {} is missing in {}".format(image_folder, path))
            maybe_download_and_extract(image_zip, path, url, extract=True)
-             del_file(path + '/' + image_zip)
-             os.system("mv {} {}".format(path + '/images', path + '/' + image_folder))
+             del_file(os.path.join(path, image_zip))
+             # os.system("mv {} {}".format(os.path.join(path, 'images'), os.path.join(path, image_folder)))
+             shutil.move(os.path.join(path, 'images'), os.path.join(path, image_folder))
        else:
            logging.info("[Flickr1M] {} exists in {}".format(image_folder, path))

    # download tag
-     if folder_exists(path + "/tags") is False:
+     if folder_exists(os.path.join(path, "tags")) is False:
        logging.info("[Flickr1M] tag files is nonexistent in {}".format(path))
        maybe_download_and_extract(tag_zip, path, url, extract=True)
-         del_file(path + '/' + tag_zip)
+         del_file(os.path.join(path, tag_zip))
    else:
        logging.info("[Flickr1M] tags exists in {}".format(path))

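Replacing os.system("mv ...") with shutil.move also drops the dependency on a POSIX shell, so the rename works on Windows and is not confused by spaces in paths. A self-contained sketch of the same move done portably (the directory names are hypothetical):

import os
import shutil
import tempfile

root = tempfile.mkdtemp()
src = os.path.join(root, 'images')
os.mkdir(src)

# Renames the directory in pure Python: no shell, no quoting issues.
shutil.move(src, os.path.join(root, 'images_0'))
print(os.listdir(root))  # ['images_0']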
@@ -761,17 +865,19 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
    for folder in images_folder_list[0:size * 10]:
        tmp = load_file_list(path=folder, regx='\\.jpg', printable=False)
        tmp.sort(key=lambda s: int(s.split('.')[-2]))  # ddd.jpg
-         images_list.extend([folder + '/' + x for x in tmp])
+         images_list.extend([os.path.join(folder, x) for x in tmp])

    # 2. tag path list
    tag_list = []
-     tag_folder_list = load_folder_list(path + "/tags")
-     tag_folder_list.sort(key=lambda s: int(s.split('/')[-1]))  # folder/images/ddd
+     tag_folder_list = load_folder_list(os.path.join(path, "tags"))
+
+     # tag_folder_list.sort(key=lambda s: int(s.split("/")[-1]))  # folder/images/ddd
+     tag_folder_list.sort(key=lambda s: int(os.path.basename(s)))

    for folder in tag_folder_list[0:size * 10]:
        tmp = load_file_list(path=folder, regx='\\.txt', printable=False)
        tmp.sort(key=lambda s: int(s.split('.')[-2]))  # ddd.txt
-         tmp = [folder + '/' + s for s in tmp]
+         tmp = [os.path.join(folder, s) for s in tmp]
        tag_list += tmp

    # 3. select images
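The new sort key uses os.path.basename instead of splitting on '/', so it still extracts the numeric folder name when load_folder_list returns paths with OS-specific separators. A small check (the paths are hypothetical):

import os

folders = [os.path.join('data', 'flickr1M', 'tags', str(i)) for i in (10, 2, 1)]
folders.sort(key=lambda s: int(os.path.basename(s)))
print([os.path.basename(f) for f in folders])  # ['1', '2', '10'], numeric order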