Skip to content

Commit f18d2cc

Browse files
committed
fix type
1 parent eb3198e commit f18d2cc

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2538,10 +2538,11 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
25382538
raise NotImplementedError()
25392539

25402540

2541-
def aten_le(self: TensorType, other: TensorType) -> TensorType:
2541+
@torch_op("aten::le")
2542+
def aten_le(self: TReal, other: TReal) -> BOOL:
25422543
# le.Tensor(Tensor self, Tensor other) -> Tensor
25432544

2544-
raise NotImplementedError()
2545+
return op.Less(self, other)
25452546

25462547

25472548
def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
@@ -2957,10 +2958,11 @@ def aten_min(self: TensorType) -> TensorType:
29572958
raise NotImplementedError()
29582959

29592960

2960-
def aten_minimum(self: TensorType, other: TensorType) -> TensorType:
2961+
@torch_op("aten::minimum")
2962+
def aten_minimum(self: TReal, other: TReal) -> TReal:
29612963
# minimum(Tensor self, Tensor other) -> Tensor
29622964

2963-
raise NotImplementedError()
2965+
return op.Min(self, other)
29642966

29652967

29662968
def aten_miopen_batch_norm(

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def _topk_input_wrangler(
278278
"gt": core_ops.aten_gt,
279279
"isinf": core_ops.aten_isinf,
280280
"log": core_ops.aten_log,
281+
"le": core_ops.aten_le,
281282
"log10": core_ops.aten_log10,
282283
"log1p": core_ops.aten_log1p,
283284
"log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
@@ -290,6 +291,7 @@ def _topk_input_wrangler(
290291
"lt": core_ops.aten_lt,
291292
"matmul": core_ops.aten_matmul,
292293
"maximum": core_ops.aten_maximum,
294+
"minimum": core_ops.aten_minimum,
293295
"mm": core_ops.aten_mm,
294296
"mul": core_ops.aten_mul,
295297
"ne": core_ops.aten_ne,

0 commit comments

Comments
 (0)