Skip to content

Commit dfaf174

Browse files
authored
add ops: ge, le, maximum, minimum, softmax (#313)
1 parent b29a9fe commit dfaf174

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1980,10 +1980,11 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
19801980
raise NotImplementedError()
19811981

19821982

@torch_op("aten::ge")
def aten_ge(self: TReal, other: TReal) -> BOOL:
    """Return the element-wise truth value of ``self >= other``.

    ge.Tensor(Tensor self, Tensor other) -> Tensor
    """
    # BUG FIX: ``ge`` means greater-or-EQUAL. The original used op.Greater,
    # which is strict ``>`` and wrongly returns False for equal elements.
    return op.GreaterOrEqual(self, other)
19871988

19881989

19891990
def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]:
@@ -2537,10 +2538,11 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
25372538
raise NotImplementedError()
25382539

25392540

@torch_op("aten::le")
def aten_le(self: TReal, other: TReal) -> BOOL:
    """Return the element-wise truth value of ``self <= other``.

    le.Tensor(Tensor self, Tensor other) -> Tensor
    """
    # BUG FIX: ``le`` means less-or-EQUAL. The original used op.Less,
    # which is strict ``<`` and wrongly returns False for equal elements.
    return op.LessOrEqual(self, other)
25442546

25452547

25462548
def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
@@ -2925,10 +2927,11 @@ def aten_max_pool3d(
29252927
raise NotImplementedError()
29262928

29272929

@torch_op("aten::maximum")
def aten_maximum(self: TReal, other: TReal) -> TReal:
    """maximum(Tensor self, Tensor other) -> Tensor

    Element-wise maximum of the two input tensors.
    """
    return op.Max(self, other)
29322935

29332936

29342937
def aten_mean(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@@ -2955,10 +2958,11 @@ def aten_min(self: TensorType) -> TensorType:
29552958
raise NotImplementedError()
29562959

29572960

@torch_op("aten::minimum")
def aten_minimum(self: TReal, other: TReal) -> TReal:
    """minimum(Tensor self, Tensor other) -> Tensor

    Element-wise minimum of the two input tensors.
    """
    return op.Min(self, other)
29622966

29632967

29642968
def aten_miopen_batch_norm(

onnxscript/function_libs/torch_aten/ops/special.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,21 @@ def aten_special_sinc(self: TensorType) -> TensorType:
338338
raise NotImplementedError()
339339

340340

@torch_op("aten::softmax")
def aten_special_softmax(
    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
) -> TFloatOrBFloat16:
    """special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor

    Softmax over axis ``dim``, with the result cast to ``dtype``.
    """
    # ONNX Softmax needs at least rank 1: promote a 0-d input to a
    # one-element vector, run softmax, then squeeze back to a scalar.
    # NOTE(review): the default ``dtype`` always casts the result to float32
    # even when the caller requested no dtype — confirm callers pass the
    # input's own dtype when they want it preserved.
    input_is_scalar = op.Size(op.Shape(self)) == 0
    if input_is_scalar:
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
    output = op.Softmax(self, axis=dim)
    output = op.Cast(output, to=dtype)
    if input_is_scalar:
        output = op.Squeeze(output)

    return output
347356

348357

349358
def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,13 @@ def _log_softmax_input_wrangler(
205205
return args, kwargs
206206

207207

208+
def _softmax_input_wrangler(
209+
args: list[Any], kwargs: dict[str, Any]
210+
) -> tuple[list[Any], dict[str, Any]]:
211+
kwargs["dim"] = args.pop()
212+
return args, kwargs
213+
214+
208215
def _topk_input_wrangler(
209216
args: list[Any], kwargs: dict[str, Any]
210217
) -> tuple[list[Any], dict[str, Any]]:
@@ -267,9 +274,11 @@ def _topk_input_wrangler(
267274
"fmod": core_ops.aten_fmod,
268275
"full": (core_ops.aten_full, _full_input_wrangler),
269276
"full_like": core_ops.aten_full_like,
277+
"ge": core_ops.aten_ge,
270278
"gt": core_ops.aten_gt,
271279
"isinf": core_ops.aten_isinf,
272280
"log": core_ops.aten_log,
281+
"le": core_ops.aten_le,
273282
"log10": core_ops.aten_log10,
274283
"log1p": core_ops.aten_log1p,
275284
"log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
@@ -281,6 +290,8 @@ def _topk_input_wrangler(
281290
"logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
282291
"lt": core_ops.aten_lt,
283292
"matmul": core_ops.aten_matmul,
293+
"maximum": core_ops.aten_maximum,
294+
"minimum": core_ops.aten_minimum,
284295
"mm": core_ops.aten_mm,
285296
"mul": core_ops.aten_mul,
286297
"ne": core_ops.aten_ne,
@@ -316,6 +327,7 @@ def _topk_input_wrangler(
316327
"sin": core_ops.aten_sin,
317328
"sinh": core_ops.aten_sinh,
318329
"slice": core_ops.aten_slice,
330+
"softmax": (special_ops.aten_special_softmax, _softmax_input_wrangler),
319331
"sqrt": core_ops.aten_sqrt,
320332
"sub": core_ops.aten_sub,
321333
"t": core_ops.aten_t,

0 commit comments

Comments
 (0)