@@ -384,13 +384,48 @@ def compareResults(self, results, xla_results, rel_err=1e-2, abs_err=1e-5):
384
384
def runAtenTest (self , tensors , fn , device = None , rel_err = 1e-2 , abs_err = 1e-5 ):
385
385
if device is None :
386
386
device = xm .xla_device ()
387
+
388
+ def to_device (tensors ):
389
+ return [
390
+ x .to (device ).clone ().detach ().requires_grad_ (x .requires_grad )
391
+ if isinstance (x , torch .Tensor ) else x
392
+ for x in tensors
393
+ ]
394
+
395
+ orig = to_device (xu .as_list (tensors ))
396
+
387
397
tensors = xu .as_list (tensors )
388
- xla_tensors = [
389
- x .to (device ).detach ().requires_grad_ (x .requires_grad ) for x in tensors
390
- ]
398
+ xla_tensors = to_device (tensors )
391
399
results = xu .as_list (fn (* tensors ))
392
400
xla_results = xu .as_list (fn (* xla_tensors ))
393
- self .compareResults (results , xla_results , rel_err = rel_err , abs_err = abs_err )
401
+
402
+ try :
403
+ self .compareResults (results , xla_results , rel_err = rel_err , abs_err = abs_err )
404
+ except :
405
+ xla_tensors = to_device (orig )
406
+
407
+ import torch_xla .debug .metrics as met
408
+ met .clear_all ()
409
+ xla_results = xu .as_list (fn (* xla_tensors ))
410
+ print ("++++++++++++++++++++++ Fallback" )
411
+ print (met .executed_fallback_ops ())
412
+
413
+ print ("++++++++++++++++++++++ CPU Tensors" )
414
+ print ("++++++++++++++++++++++++++++++++++ Input" )
415
+ for t in tensors :
416
+ print (t )
417
+ print ("++++++++++++++++++++++++++++++++++ Output" )
418
+ for r in results :
419
+ print (r )
420
+ print ("++++++++++++++++++++++ XLA Tensors" )
421
+ print ("++++++++++++++++++++++++++++++++++ Input" )
422
+ for t in xla_tensors :
423
+ print (t )
424
+ print ("++++++++++++++++++++++++++++++++++ Output" )
425
+ for r in xla_results :
426
+ print (r )
427
+
428
+ raise
394
429
395
430
396
431
@contextmanager
0 commit comments