Skip to content

Commit d4bbee7

Browse files
authored
[torchlib] Update operator:pow implementation (#2069)
Register it to aten_pow instead because exponent may not be a tensor. Fixes pytorch/pytorch#147606
1 parent 013d28c commit d4bbee7

File tree

1 file changed

+8
-9
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+8
-9
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6651,22 +6651,21 @@ def aten_positive(self: TensorType) -> TensorType:
66516651
raise NotImplementedError()
66526652

66536653

6654-
@torch_op(
6655-
("aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar"),
6656-
trace_only=True,
6657-
)
6654+
@torch_op(("aten::pow.Tensor_Tensor", "_operator::pow"), trace_only=True)
66586655
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
66596656
"""pow(Tensor self, Tensor exponent) -> Tensor"""
66606657
return op.Pow(self, exponent)
66616658

66626659

6663-
@torch_op(
6664-
("_operator::pow", "aten::pow.Scalar"),
6665-
trace_only=True,
6666-
)
6660+
@torch_op("aten::pow.Tensor_Scalar", trace_only=True)
6661+
def aten_pow_tensor_scalar(self: TReal, exponent: float) -> TReal:
6662+
"""pow(Tensor self, Scalar exponent) -> Tensor"""
6663+
return op.Pow(self, exponent)
6664+
6665+
6666+
@torch_op("aten::pow.Scalar", trace_only=True)
66676667
def aten_pow_scalar(self: float, exponent: TTensor) -> TTensor:
66686668
"""pow.Scalar(Scalar self, Tensor exponent) -> Tensor"""
6669-
66706669
return op.Pow(op.Cast(self, to=exponent.dtype), exponent)
66716670

66726671

0 commit comments

Comments
 (0)