@@ -1662,13 +1662,32 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
1662
1662
return op .SplitToSequence (self , list_split , axis = dim )
1663
1663
1664
1664
1665
- @torch_op (("aten::clamp" , "aten::clamp.Tensor" ), trace_only = True )
1666
- def aten_clamp (self : TReal , min : Optional [TReal ] = None , max : Optional [TReal ] = None ) -> TReal :
1667
- """clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
1668
- clamped = self
1665
+ @torch_op ("aten::clamp" , trace_only = True )
1666
+ def aten_clamp (self : TReal , min : Optional [float ] = None , max : Optional [float ] = None ) -> TReal :
1667
+ """clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"""
1669
1668
1670
1669
if min is None and max is None :
1671
- return clamped
1670
+ return op .Identity (self )
1671
+
1672
+ if min is not None :
1673
+ min = op .CastLike (min , self )
1674
+
1675
+ if max is not None :
1676
+ max = op .CastLike (max , self )
1677
+
1678
+ return op .Clip (self , min , max )
1679
+
1680
+
1681
+ @torch_op ("aten::clamp.Tensor" , trace_only = True )
1682
+ def aten_clamp_tensor (
1683
+ self : TReal , min : Optional [TReal ] = None , max : Optional [TReal ] = None
1684
+ ) -> TReal :
1685
+ """clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
1686
+
1687
+ if min is None and max is None :
1688
+ return op .Identity (self )
1689
+
1690
+ clamped = self
1672
1691
1673
1692
# If min is greater than max torch.clamp(..., min, max)
1674
1693
# sets all elements in input to the value of max.
@@ -1684,11 +1703,20 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] =
1684
1703
return clamped
1685
1704
1686
1705
1687
- @torch_op (("aten::clamp_max" , "aten::clamp_max.Tensor" ), trace_only = True )
1688
- def aten_clamp_max (self : TReal , max_ : TReal ) -> TReal :
1689
- """clamp_max(Tensor self, Tensor max) -> Tensor"""
1706
+ @torch_op ("aten::clamp_max" , trace_only = True )
1707
+ def aten_clamp_max (self : TReal , max_ : float ) -> TReal :
1708
+ """clamp_max(Tensor self, Scalar max) -> Tensor"""
1709
+
1710
+ # This implementation does not intend to handle when self is an empty tensor
1711
+ max_ = op .CastLike (max_ , self )
1712
+ return op .Clip (self , None , max_ )
1690
1713
1691
- # This implementation does not intent to handle when self is an empty tensor
1714
+
1715
+ @torch_op ("aten::clamp_max.Tensor" , trace_only = True )
1716
+ def aten_clamp_max_tensor (self : TReal , max_ : TReal ) -> TReal :
1717
+ """clamp_max.Tensor(Tensor self, Tensor max) -> Tensor"""
1718
+
1719
+ # This implementation does not intend to handle when self is an empty tensor
1692
1720
max_rank = len (max_ .shape )
1693
1721
if max_rank == 0 :
1694
1722
max_ = op .CastLike (max_ , self )
@@ -1699,11 +1727,20 @@ def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
1699
1727
return result
1700
1728
1701
1729
1702
- @torch_op (("aten::clamp_min" , "aten::clamp_min.Tensor" ), trace_only = True )
1703
- def aten_clamp_min (self : TReal , min_ : TReal ) -> TReal :
1704
- """clamp_min(Tensor self, Tensor min) -> Tensor"""
1730
+ @torch_op ("aten::clamp_min" , trace_only = True )
1731
+ def aten_clamp_min (self : TReal , min_ : float ) -> TReal :
1732
+ """clamp_min(Tensor self, Scalar min) -> Tensor"""
1733
+
1734
+ # This implementation does not intend to handle when self is an empty tensor
1735
+ min_ = op .CastLike (min_ , self )
1736
+ return op .Clip (self , min_ , None )
1737
+
1738
+
1739
+ @torch_op ("aten::clamp_min.Tensor" , trace_only = True )
1740
+ def aten_clamp_min_tensor (self : TReal , min_ : TReal ) -> TReal :
1741
+ """clamp_min.Tensor(Tensor self, Tensor min) -> Tensor"""
1705
1742
1706
- # This implementation does not intent to handle when self is an empty tensor
1743
+ # This implementation does not intend to handle when self is an empty tensor
1707
1744
min_rank = len (min_ .shape )
1708
1745
if min_rank == 0 :
1709
1746
min_ = op .CastLike (min_ , self )
0 commit comments