@@ -2742,10 +2742,6 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
2742
2742
(
2743
2743
"aten::div.Tensor" ,
2744
2744
"aten::div.Scalar" ,
2745
- # When rounding_mode is None, performs a true division
2746
- # https://pytorch.org/docs/stable/generated/torch.div.html
2747
- "aten::div.Tensor_mode" ,
2748
- "aten::div.Scalar_mode" ,
2749
2745
"aten::divide.Tensor" ,
2750
2746
"aten::divide.Scalar" ,
2751
2747
"aten::true_divide.Tensor" ,
@@ -2799,41 +2795,45 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:
2799
2795
2800
2796
2801
2797
@torch_op (("aten::div.Tensor_mode" , "aten::div.Scalar_mode" ), trace_only = True )
2802
- def aten_div_mode (self : TFloat , other : TFloat , rounding_mode : str ) -> TFloat :
2798
+ def aten_div_mode (self : TFloat , other : TFloat , rounding_mode : Optional [ str ] = None ) -> TFloat :
2803
2799
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
2804
2800
2805
- # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2806
- assert rounding_mode in {"trunc" , "floor" }
2801
+ assert rounding_mode in {"trunc" , "floor" , None }
2807
2802
2808
2803
if rounding_mode == "trunc" :
2809
2804
# Rounds the results of the division towards zero.
2810
2805
# Equivalent to C-style integer division
2811
- result = aten_trunc (op .Div (self , other ))
2812
- else : # rounding_mode == "floor"
2813
- result = op .Floor (op .Div (self , other ))
2806
+ return aten_trunc (op .Div (self , other ))
2807
+ if rounding_mode == "floor" :
2808
+ return op .Floor (op .Div (self , other ))
2814
2809
2815
- return result
2810
+ return op . Div ( self , other )
2816
2811
2817
2812
2818
2813
@torch_op (("aten::div.Tensor_mode" , "aten::div.Scalar_mode" ), trace_only = True )
2819
- def aten_div_mode_int (self : TInt , other : TInt , rounding_mode : str ) -> TInt :
2814
+ def aten_div_mode_int (
2815
+ self : TInt , other : TInt , rounding_mode : Optional [str ] = None
2816
+ ) -> TensorType :
2820
2817
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
2821
2818
2822
2819
Variant for integer inputs.
2823
2820
"""
2824
- # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2825
- assert rounding_mode in {"trunc" , "floor" }
2821
+ assert rounding_mode in {"trunc" , "floor" , None }
2826
2822
2827
2823
quotient = op .Div (op .Cast (self , to = FLOAT .dtype ), op .Cast (other , to = FLOAT .dtype ))
2828
2824
2829
2825
if rounding_mode == "trunc" :
2830
2826
# Rounds the results of the division towards zero.
2831
2827
# Equivalent to C-style integer division
2832
2828
result = aten_trunc (quotient )
2833
- else : # rounding_mode == "floor"
2829
+ return op .CastLike (result , self )
2830
+ if rounding_mode == "floor" :
2834
2831
result = op .Floor (quotient )
2832
+ return op .CastLike (result , self )
2835
2833
2836
- return op .CastLike (result , self )
2834
+ assert rounding_mode is None
2835
+ # When rounding_mode is None, the return type is float32
2836
+ return quotient
2837
2837
2838
2838
2839
2839
@torch_op ("aten::dot" )
@@ -8465,7 +8465,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
8465
8465
raise NotImplementedError ()
8466
8466
8467
8467
8468
- @torch_op ("aten::trunc" )
8468
+ @torch_op ("aten::trunc" , trace_only = True )
8469
8469
def aten_trunc (self : TFloat ) -> TFloat :
8470
8470
"""trunc(Tensor self) -> Tensor"""
8471
8471
0 commit comments