@@ -247,32 +247,39 @@ def main():
247
247
help = 'learning rate (default: 1.0)' )
248
248
parser .add_argument ('--gamma' , type = float , default = 0.7 , metavar = 'M' ,
249
249
help = 'Learning rate step gamma (default: 0.7)' )
250
- parser .add_argument ('--no-cuda' , action = 'store_true' , default = False ,
250
+ parser .add_argument ('--no-cuda' , action = 'store_true' ,
251
251
help = 'disables CUDA training' )
252
- parser .add_argument ('--no-mps' , action = 'store_true' , default = False ,
252
+ parser .add_argument ('--no-xpu' , action = 'store_true' ,
253
+ help = 'disables XPU training' )
254
+ parser .add_argument ('--no-mps' , action = 'store_true' ,
253
255
help = 'disables macOS GPU training' )
254
- parser .add_argument ('--dry-run' , action = 'store_true' , default = False ,
256
+ parser .add_argument ('--dry-run' , action = 'store_true' ,
255
257
help = 'quickly check a single pass' )
256
258
parser .add_argument ('--seed' , type = int , default = 1 , metavar = 'S' ,
257
259
help = 'random seed (default: 1)' )
258
260
parser .add_argument ('--log-interval' , type = int , default = 10 , metavar = 'N' ,
259
261
help = 'how many batches to wait before logging training status' )
260
- parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
262
+ parser .add_argument ('--save-model' , action = 'store_true' ,
261
263
help = 'For Saving the current Model' )
262
264
args = parser .parse_args ()
263
265
264
266
use_cuda = not args .no_cuda and torch .cuda .is_available ()
267
+ use_xpu = not args .no_xpu and torch .xpu .is_available ()
265
268
use_mps = not args .no_mps and torch .backends .mps .is_available ()
266
269
267
270
torch .manual_seed (args .seed )
268
271
269
272
if use_cuda :
270
273
device = torch .device ("cuda" )
274
+ elif use_xpu :
275
+ device = torch .device ("xpu" )
271
276
elif use_mps :
272
277
device = torch .device ("mps" )
273
278
else :
274
279
device = torch .device ("cpu" )
275
280
281
+ print ('Device to use: ' , device )
282
+
276
283
train_kwargs = {'batch_size' : args .batch_size }
277
284
test_kwargs = {'batch_size' : args .test_batch_size }
278
285
if use_cuda :
0 commit comments