@@ -2046,7 +2046,7 @@ def aten_convolution_overrideable(
     raise NotImplementedError()


-@torch_op("aten::copy")
+@torch_op(("aten::copy", "aten::_to_copy"))
 def aten_copy(
     self: TTensor, src: TTensor, non_blocking: bool = False  # pylint: disable=unused-argument
 ) -> TTensor:
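# Context for the hunk above (not part of the diff): aten::_to_copy is the
# functional copy op that calls such as Tensor.to() typically decompose to
# during export, so registering it alongside aten::copy routes both through
# the same function.  A quick eager-mode way to see the op, assuming a recent
# PyTorch where torch.ops.aten._to_copy is directly callable:
import torch

x = torch.arange(4, dtype=torch.int64)
y = torch.ops.aten._to_copy(x, dtype=torch.float32)
print(y.dtype, y is x)  # torch.float32 False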
@@ -5456,6 +5456,20 @@ def aten__native_batch_norm_no_training(
     )


+@torch_op("aten::_native_batch_norm_legit.no_stats", trace_only=True)
+def aten__native_batch_norm_no_stats(
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    training: bool = False,
+    momentum: float = 0.9,
+    eps: float = 1e-05,
+) -> Tuple[TFloat, TFloat, TFloat]:
+    """_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)"""
+
+    return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps)
+
+
 @torch_op(("aten::native_batch_norm", "aten::_native_batch_norm_legit"), trace_only=True)
 def aten_native_batch_norm(
     input: TFloat,
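# A minimal reference sketch (not part of the diff) of what the new
# _native_batch_norm_legit.no_stats overload is expected to return when
# training is True: the normalized output plus the per-channel batch mean and
# rstd.  Plain PyTorch is assumed, and batch_norm_no_stats_reference is a
# made-up helper name used only for illustration.
import torch

def batch_norm_no_stats_reference(x, weight=None, bias=None, eps=1e-5):
    # Reduce over every axis except the channel axis (dim 1).
    dims = [d for d in range(x.dim()) if d != 1]
    shape = [1, -1] + [1] * (x.dim() - 2)
    mean = x.mean(dim=dims)
    var = ((x - mean.view(shape)) ** 2).mean(dim=dims)  # biased batch variance
    rstd = (var + eps).rsqrt()
    out = (x - mean.view(shape)) * rstd.view(shape)
    if weight is not None:
        out = out * weight.view(shape)
    if bias is not None:
        out = out + bias.view(shape)
    return out, mean, rstd

x = torch.randn(2, 3, 4, 4)
out, mean, rstd = batch_norm_no_stats_reference(x)
print(out.shape, mean.shape, rstd.shape)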
@@ -5556,12 +5570,131 @@ def _aten_native_batch_norm_inference_onnx(
         momentum=momentum,
         training_mode=training,
     )
+    # NOTE: mean and var are omitted in inference mode
     # Cannot return 2 dup output, so have to do twice with different variable name
-    empty_mean = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
-    empty_var = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
+    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
     return norm, empty_mean, empty_var


+# TODO: This op uses duplicated code from aten_native_batch_norm and needs to be
+# refactored later. https://github.com/microsoft/onnxscript/issues/1125
+# NOTE: This op is invoked by PyTorch Functionalization and is not in
+# native_functions.yaml. It can be found in torch/_decomp/decompositions.py
+@torch_op("aten::_native_batch_norm_legit_functional", trace_only=True)
+def aten__native_batch_norm_legit_functional(
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    running_mean: Optional[TFloat] = None,
+    running_var: Optional[TFloat] = None,
+    training: bool = False,
+    momentum: float = 0.9,
+    eps: float = 1e-05,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    if weight is None:  # Set to 1.0 as default
+        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
+
+    if bias is None:  # Set to 0.0 as default
+        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+
+    axes = list(range(len(input.shape)))
+    axes.pop(1)
+    axes = op.Constant(value_ints=axes)
+    if running_mean is None:  # Using input mean
+        running_mean = op.Squeeze(op.ReduceMean(input, axes))
+
+    if running_var is None:  # Using input var
+        mean = op.ReduceMean(input, axes)
+        input_sub_mean = op.Sub(input, mean)
+        sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
+        running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))
+
+    # Have to split into 2 private functions, because the training function returns 3 outputs
+    # while the inference function returns 1 output
+    if training is True:
+        norm, mean, var, new_mean, new_var = _aten__native_batch_norm_training_functional_onnx(
+            input, weight, bias, running_mean, running_var, axes, training, momentum, eps
+        )
+    else:
+        (
+            norm,
+            mean,
+            var,
+            new_mean,
+            new_var,
+        ) = _aten__native_batch_norm_inference_functional_onnx(
+            input, weight, bias, running_mean, running_var, training, momentum, eps
+        )
+    return norm, mean, var, new_mean, new_var
+
+
+@torch_op("aten::_native_batch_norm_legit_functional", private=True)
+def _aten__native_batch_norm_training_functional_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    running_mean: TFloat,
+    running_var: TFloat,
+    axes: INT64,
+    training: bool,
+    momentum: float,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    # Assert(training is True)
+    norm, running_mean, running_var = op.BatchNormalization(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        epsilon=eps,
+        momentum=momentum,
+        training_mode=training,
+    )
+    # Compute var and rstd
+    mean = op.ReduceMean(input, axes)
+    input_sub_mean = op.Sub(input, mean)
+    sqr = op.Mul(input_sub_mean, input_sub_mean)
+    var = op.ReduceMean(sqr, axes, keepdims=False)
+    rstd = op.Div(1.0, op.Sqrt(var + eps))
+    # Get mean again with size = [1, C]
+    mean = op.ReduceMean(input, axes, keepdims=False)
+    # NOTE: Fixed to be FLOAT dtype
+    running_mean = op.Cast(running_mean, to=FLOAT.dtype)
+    running_var = op.Cast(running_var, to=FLOAT.dtype)
+    return norm, mean, rstd, running_mean, running_var
+
+
+@torch_op("aten::_native_batch_norm_legit_functional", private=True)
+def _aten__native_batch_norm_inference_functional_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    running_mean: TFloat,
+    running_var: TFloat,
+    training: bool,
+    momentum: float,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    # Assert(training is False)
+    norm = op.BatchNormalization(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        epsilon=eps,
+        momentum=momentum,
+        training_mode=training,
+    )
+    # NOTE: mean and var are omitted in inference mode
+    # Cannot return 2 dup output, so have to do twice with different variable name
+    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    return norm, empty_mean, empty_var, running_mean, running_var
+
+
 def aten_native_batch_norm_backward(
     grad_out: TensorType,
     input: TensorType,
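# A hypothetical sketch (not part of the diff) of the five-output contract that
# aten::_native_batch_norm_legit_functional is expected to follow: the same
# (output, batch mean, batch rstd) triple as the training path above, plus the
# updated running statistics returned as new tensors rather than mutated in
# place.  Plain PyTorch, its usual running-stat update convention, and the
# made-up helper name legit_functional_reference are all assumptions here.
import torch

def legit_functional_reference(x, weight, bias, running_mean, running_var,
                               momentum=0.1, eps=1e-5):
    dims = [d for d in range(x.dim()) if d != 1]
    shape = [1, -1] + [1] * (x.dim() - 2)
    mean = x.mean(dim=dims)
    var = ((x - mean.view(shape)) ** 2).mean(dim=dims)  # biased batch variance
    rstd = (var + eps).rsqrt()
    out = (x - mean.view(shape)) * rstd.view(shape) * weight.view(shape) + bias.view(shape)
    # Running stats are returned, not mutated (assumed convention:
    # new = (1 - momentum) * old + momentum * batch_stat, with an unbiased
    # batch variance for running_var).
    n = x.numel() // x.shape[1]
    new_mean = (1 - momentum) * running_mean + momentum * mean
    new_var = (1 - momentum) * running_var + momentum * var * n / (n - 1)
    return out, mean, rstd, new_mean, new_var

x = torch.randn(2, 3, 4, 4)
outs = legit_functional_reference(x, torch.ones(3), torch.zeros(3), torch.zeros(3), torch.ones(3))
print([t.shape for t in outs])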