@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path
 
 
-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
+        trans = [
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+        ]
+        if args.auto_augment is not None:
+            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
+            trans.append(transforms.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            normalize,
+        ])
+        if args.random_erase > 0:
+            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-    if cache_dataset:
+            transforms.Compose(trans))
+    if args.cache_dataset:
         print("Saving dataset_train to {}".format(cache_path))
         utils.mkdir(os.path.dirname(cache_path))
         utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)
 
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
                 transforms.ToTensor(),
                 normalize,
             ]))
-    if cache_dataset:
+    if args.cache_dataset:
         print("Saving dataset_test to {}".format(cache_path))
         utils.mkdir(os.path.dirname(cache_path))
         utils.save_on_master((dataset_test, valdir), cache_path)
 
     print("Creating data loaders")
-    if distributed:
+    if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
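The ordering of the new transforms matters. Below is a minimal standalone sketch of the assembly introduced above, assuming a torchvision release that ships transforms.AutoAugment and AutoAugmentPolicy (0.10 or later), with 'imagenet' and 0.1 standing in for args.auto_augment and args.random_erase:

    from torchvision import transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    auto_augment = 'imagenet'  # stand-in for args.auto_augment
    random_erase = 0.1         # stand-in for args.random_erase

    trans = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    if auto_augment is not None:
        # AutoAugment operates on PIL images, so it is appended before ToTensor()
        trans.append(transforms.AutoAugment(
            policy=transforms.AutoAugmentPolicy(auto_augment)))
    trans.extend([
        transforms.ToTensor(),
        normalize,
    ])
    if random_erase > 0:
        # RandomErasing expects a tensor, so it must come after ToTensor()
        trans.append(transforms.RandomErasing(p=random_erase))
    pipeline = transforms.Compose(trans)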
@@ -155,8 +163,7 @@ def main(args):
 
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
@@ -173,8 +180,15 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss()
 
-    optimizer = torch.optim.SGD(
-        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    opt_name = args.opt.lower()
+    if opt_name == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    elif opt_name == 'rmsprop':
+        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
+                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+    else:
+        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
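A note on the RMSprop settings: PyTorch adds eps outside the square root while TensorFlow adds it inside, so eps=0.0316 (roughly sqrt(1e-3)) appears chosen to approximate a TensorFlow-style epsilon of 1e-3, and alpha=0.9 is the squared-gradient smoothing constant. A toy sketch with a stand-in model:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 2)  # stand-in for the real network
    # eps=0.0316 ~ sqrt(1e-3): approximates a TF-style epsilon applied inside the sqrt
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.1, momentum=0.9,
                                    weight_decay=1e-4, eps=0.0316, alpha=0.9)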
@@ -238,6 +252,7 @@ def parse_args():
                         help='number of total epochs to run')
     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                         help='number of data loading workers (default: 16)')
+    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
     parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum')
@@ -275,6 +290,8 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
+    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
 
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
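A hypothetical invocation exercising the new flags (flag names per the argparse definitions above; --auto-augment accepts torchvision AutoAugmentPolicy values such as 'imagenet', 'cifar10', or 'svhn', and --data-path is assumed from the existing args.data_path):

    python train.py --data-path /path/to/imagenet --opt rmsprop \
        --auto-augment imagenet --random-erase 0.1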