@@ -765,35 +765,42 @@ def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
     return clamped
 
 
-@torch_op("aten::clamp_max.Scalar", overload=True)
-def aten_clamp_max_scalar(self, max_):
-    # clamp_max(Tensor self, Scalar max) -> Tensor
-
-    max_ = op.CastLike(max_, self)
-    return op.Clip(self, None, max_)
-
-
-@torch_op("aten::clamp_max.Tensor")
-def aten_clamp_max_tensor(self, max_):
+@torch_op("aten::clamp_max")
+def aten_clamp_max(self, max_):
     # clamp_max(Tensor self, Tensor max) -> Tensor
 
-    return op.Min(self, max_)
-
+    self_size = op.Size(self)
+    max_shape = op.Shape(max_)
+    max_rank = op.Size(max_shape)
+    if self_size == 0:
+        result = op.Expand(self, max_shape)
+    else:
+        if max_rank == 0:
+            max_ = op.CastLike(max_, self)
+            result = op.Clip(self, None, max_)
+        else:
+            result = op.Min(self, max_)
 
-@torch_op("aten::clamp_min.Scalar", overload=True)
-def aten_clamp_min_scalar(self, min_):
-    # clamp_min(Tensor self, Scalar min) -> Tensor
-    # NOTE: min_ is a rank 0 tensor.
-    # TODO(justinchuby): Specify the type constraints.
-    min_ = op.CastLike(min_, self)
-    return op.Clip(self, min_, None)
+    return result
 
 
-@torch_op("aten::clamp_min.Tensor")
-def aten_clamp_min_tensor(self, min_):
+@torch_op("aten::clamp_min")
+def aten_clamp_min(self, min_):
     # clamp_min(Tensor self, Tensor min) -> Tensor
-    # TODO(justinchuby): Specify the type constraints.
-    return op.Max(self, min_)
+
+    self_size = op.Size(self)
+    min_shape = op.Shape(min_)
+    min_rank = op.Size(min_shape)
+    if self_size == 0:
+        result = op.Expand(self, min_shape)
+    else:
+        if min_rank == 0:
+            min_ = op.CastLike(min_, self)
+            result = op.Clip(self, min_, None)
+        else:
+            result = op.Max(self, min_)
+
+    return result
 
 
 def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
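Note: below is a minimal NumPy sketch of the dispatch logic introduced in aten_clamp_max above. NumPy stands in for the ONNX runtime ops and the helper name is illustrative, not part of this commit; aten_clamp_min is symmetric, with Clip(min_, None)/Max in place of Clip(None, max_)/Min.

import numpy as np

def clamp_max_reference(self_array: np.ndarray, max_array: np.ndarray) -> np.ndarray:
    # Illustrative stand-in for the decomposition above, not the commit's code.
    if self_array.size == 0:
        # Empty input: mirrors op.Expand(self, max_shape).
        return np.broadcast_to(self_array, max_array.shape)
    if max_array.ndim == 0:
        # Rank-0 (scalar) bound: mirrors op.CastLike + op.Clip(self, None, max_).
        return np.clip(self_array, None, max_array.astype(self_array.dtype))
    # Tensor bound: mirrors op.Min(self, max_), which broadcasts.
    return np.minimum(self_array, max_array)

# e.g. clamp_max_reference(np.array([1.0, 5.0]), np.array(3.0)) -> array([1., 3.])

The rank check is why the Scalar and Tensor overloads could be merged: a rank-0 bound takes the Clip path, anything else the Min/Max path.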
@@ -3976,16 +3983,18 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorType:
 def aten_repeat(self, repeats: INT64):
     # repeat(Tensor self, SymInt[] repeats) -> Tensor
 
-    # FIXME(justinchuby): When repeats.shape == [0]
-
-    # TODO(justinchuby): Make ones_like a function when onnxscript supports it
-    # shape = ones_like(repeats) := {
-    one = op.Constant(value_int=1)
-    repeats_shape = op.Shape(repeats)
-    shape = op.Expand(one, repeats_shape)
-    # }
-    self_expanded = op.Expand(self, shape)  # type: ignore[arg-type]
-    return op.Tile(self_expanded, repeats)
+    if op.Size(repeats) == 0:
+        result = self
+    else:
+        # TODO(justinchuby): Make ones_like a function when onnxscript supports it
+        # shape = ones_like(repeats) := {
+        one = op.Constant(value_int=1)
+        repeats_shape = op.Shape(repeats)
+        shape = op.Expand(one, repeats_shape)
+        # }
+        self_expanded = op.Expand(self, shape)  # type: ignore[arg-type]
+        result = op.Tile(self_expanded, repeats)
+    return result
 
 
 def aten_repeat_interleave(
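Note: a minimal NumPy sketch of the control flow in the rewritten aten_repeat above. NumPy is a stand-in for the ONNX ops and the helper name is illustrative, not part of this commit.

import numpy as np

def repeat_reference(self_array: np.ndarray, repeats: np.ndarray) -> np.ndarray:
    # Illustrative stand-in for the decomposition above, not the commit's code.
    if repeats.size == 0:
        # Empty repeats: return self unchanged (the case the old FIXME flagged).
        return self_array
    # Mirror op.Expand(self, ones_like(repeats)): broadcast self up to the
    # rank of repeats by prepending singleton dimensions.
    target_shape = np.broadcast_shapes(self_array.shape, (1,) * repeats.size)
    expanded = np.broadcast_to(self_array, target_shape)
    # Mirror op.Tile(self_expanded, repeats).
    return np.tile(expanded, repeats)

# e.g. repeat_reference(np.arange(6).reshape(2, 3), np.array([2, 1, 1])).shape -> (2, 2, 3)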