Skip to content

Commit 339a895

Browse files
authored
feat(atenlib): logarithmic ops; test aten::full (#281)
Implement logarithmic ops: log, log10, log1p, log2, logaddexp, logaddexp2, logcumsumexp, logsumexp, logdet, log_sigmoid, xlogy. Also enable the test for `aten::full` and rename kwargs wranglers to input wranglers. Note: logcumsumexp is not numerically stable, but I will leave the fix as a TODO. Reference: https://github.com/pytorch/pytorch/pull/36308/files
1 parent 65aa427 commit 339a895

File tree

4 files changed

+117
-53
lines changed

4 files changed

+117
-53
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,15 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
186186
raise NotImplementedError()
187187

188188

189-
# @torch_op("aten::amax") # FIXME: Uncomment when CI uses onnx 1.13
189+
# @torch_op("aten::amax") # FIXME(#249): Uncomment when CI uses onnx 1.13
190190
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
191191
# amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
192192

193193
# TODO(justinchuby): Make dim optional, keepdim bool
194194
return op.ReduceMax(self, dim, keepdims=keepdim)
195195

196196

197-
# @torch_op("aten::amin") # FIXME: Uncomment when CI uses onnx 1.13
197+
# @torch_op("aten::amin") # FIXME(#249): Uncomment when CI uses onnx 1.13
198198
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
199199
# amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
200200

@@ -2575,52 +2575,68 @@ def aten_linspace(start: float, end: float, steps: int) -> TensorType:
25752575
raise NotImplementedError()
25762576

25772577

2578-
def aten_log(self: TensorType) -> TensorType:
2578+
@torch_op("log")
2579+
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
25792580
# log(Tensor self) -> Tensor
25802581

2581-
raise NotImplementedError()
2582+
return op.Log(self)
25822583

25832584

2584-
def aten_log10(self: TensorType) -> TensorType:
2585+
@torch_op("aten::log10")
2586+
def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
25852587
# log10(Tensor self) -> Tensor
25862588

2587-
raise NotImplementedError()
2589+
return op.Div(op.Log(self), op.Log(10.0))
25882590

25892591

2590-
def aten_log1p(self: TensorType) -> TensorType:
2592+
@torch_op("aten::log1p")
2593+
def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
25912594
# log1p(Tensor self) -> Tensor
25922595

2593-
raise NotImplementedError()
2596+
return op.Log(op.Add(self, 1.0))
25942597

25952598

2596-
def aten_log2(self: TensorType) -> TensorType:
2599+
@torch_op("aten::log2")
2600+
def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
25972601
# log2(Tensor self) -> Tensor
25982602

2599-
raise NotImplementedError()
2603+
return op.Div(op.Log(self), op.Log(2.0))
26002604

26012605

2602-
def aten_logaddexp(self: TensorType, other: TensorType) -> TensorType:
2606+
@torch_op("aten::logaddexp")
2607+
def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
26032608
# logaddexp(Tensor self, Tensor other) -> Tensor
26042609

2605-
raise NotImplementedError()
2610+
return op.Log(op.Add(op.Exp(self), op.Exp(other)))
26062611

26072612

2608-
def aten_logaddexp2(self: TensorType, other: TensorType) -> TensorType:
2613+
@torch_op("aten::logaddexp2")
2614+
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
26092615
# logaddexp2(Tensor self, Tensor other) -> Tensor
2616+
summation = op.Add(op.Pow(2.0, self), op.Pow(2.0, other))
26102617

2611-
raise NotImplementedError()
2618+
return op.Div(op.Log(summation), op.Log(2.0))
26122619

26132620

2614-
def aten_logcumsumexp(self: TensorType, dim: int) -> TensorType:
2621+
@torch_op("aten::logcumsumexp")
2622+
def aten_logcumsumexp(self: TFloatOrBFloat16, dim: INT64) -> TFloatOrBFloat16:
26152623
# logcumsumexp(Tensor self, int dim) -> Tensor
26162624

2617-
raise NotImplementedError()
2625+
if op.Size(op.Shape(self)) == 0:
2626+
# A scalar
2627+
result = self
2628+
else:
2629+
# FIXME(justinchuby): Ensure numerical stability
2630+
result = op.Log(op.CumSum(op.Exp(self), dim))
2631+
2632+
return result
26182633

26192634

2620-
def aten_logdet(self: TensorType) -> TensorType:
2635+
@torch_op("aten::logdet")
2636+
def aten_logdet(self: TFloat) -> TFloat:
26212637
# logdet(Tensor self) -> Tensor
26222638

2623-
raise NotImplementedError()
2639+
return op.Log(op.Det(self))
26242640

26252641

26262642
@torch_op("aten::logical_and")
@@ -2663,10 +2679,11 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
26632679
raise NotImplementedError()
26642680

26652681

2666-
def aten_logsumexp(self: TensorType, dim: Sequence[int], keepdim: bool = False) -> TensorType:
2682+
@torch_op("aten::logsumexp", trace_only=True) # FIXME(#249): Script when CI uses onnx 1.13
2683+
def aten_logsumexp(self: TReal, dim: INT64, keepdim: int = False) -> TReal:
26672684
# logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
26682685

2669-
raise NotImplementedError()
2686+
return op.ReduceLogSumExp(self, dim, keepdims=keepdim)
26702687

26712688

26722689
def aten_lshift(self: TensorType, other: TensorType) -> TensorType:
@@ -5035,12 +5052,6 @@ def aten_where(self: TTensor, condition: BOOL, other: TTensor) -> TTensor:
50355052
return op.Where(condition, self, other)
50365053

50375054

5038-
def aten_xlogy(self: TensorType, other: TensorType) -> TensorType:
5039-
# xlogy.Tensor(Tensor self, Tensor other) -> Tensor
5040-
5041-
raise NotImplementedError()
5042-
5043-
50445055
def aten_xor(self: TensorType, other: TensorType) -> TensorType:
50455056
# __xor__.Tensor(Tensor self, Tensor other) -> Tensor
50465057

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,11 @@ def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) ->
481481
return result
482482

483483

484-
def aten_log_sigmoid(self: TensorType) -> TensorType:
484+
@torch_op("aten::log_sigmoid")
485+
def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
485486
# log_sigmoid(Tensor self) -> Tensor
486487

487-
raise NotImplementedError()
488+
return op.Log(op.Sigmoid(self))
488489

489490

490491
def aten_log_sigmoid_backward(

onnxscript/function_libs/torch_aten/ops/special.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
from typing import Optional, Sequence
1515

16+
from onnxscript.function_libs.torch_aten.registration import torch_op
17+
from onnxscript.function_libs.torch_aten.typing import TFloatOrBFloat16
18+
from onnxscript.onnx_opset import opset18 as op
1619
from onnxscript.onnx_types import TensorType
1720

1821

@@ -344,10 +347,22 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType:
344347
raise NotImplementedError()
345348

346349

347-
def aten_special_xlogy(self: TensorType, other: TensorType) -> TensorType:
350+
@torch_op("aten::xlogy")
351+
def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
348352
# special_xlogy(Tensor self, Tensor other) -> Tensor
349353

350-
raise NotImplementedError()
354+
# https://pytorch.org/docs/stable/special.html#torch.special.xlogy
355+
# out := {
356+
# NaN if other == NaN
357+
# 0 if self == 0
358+
# self * log(other) otherwise
359+
# }
360+
361+
nans = op.IsNaN(other)
362+
zeros = op.Equal(self, 0)
363+
xlogy = op.Mul(self, op.Log(other))
364+
xlogy_with_nans = op.Where(nans, other, xlogy)
365+
return op.Where(zeros, self, xlogy_with_nans)
351366

352367

353368
def aten_special_zeta(self: TensorType, other: TensorType) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import onnxscript
1717
from onnxscript.function_libs.torch_aten.ops import core as core_ops
1818
from onnxscript.function_libs.torch_aten.ops import nn as nn_ops
19+
from onnxscript.function_libs.torch_aten.ops import special as special_ops
1920

2021
T = TypeVar("T")
2122

@@ -161,20 +162,39 @@ def duplicate_opinfo(opinfos: list[opinfo_core.OpInfo], name: str, new_names: tu
161162
# Modify this section ##########################################################
162163

163164

164-
def _amax_amin_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
165+
def _amax_amin_input_wrangler(
166+
args: list[Any], kwargs: dict[str, Any]
167+
) -> tuple[list[Any], dict[str, Any]]:
165168
if "dim" not in kwargs:
166169
kwargs["dim"] = None
167-
return kwargs
170+
return args, kwargs
168171

169172

170-
def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
173+
def _full_input_wrangler(
174+
args: list[Any], kwargs: dict[str, Any]
175+
) -> tuple[list[Any], dict[str, Any]]:
176+
# Remove the self argument
177+
args.pop(0)
178+
return args, kwargs
179+
180+
181+
def _upsample_input_wrangler(
182+
args: list[Any], kwargs: dict[str, Any]
183+
) -> tuple[list[Any], dict[str, Any]]:
171184
if "scale_factor" in kwargs:
172185
kwargs["scales_h"] = kwargs["scale_factor"]
173186
kwargs["scales_w"] = kwargs["scale_factor"]
174187
del kwargs["scale_factor"]
175188
if "size" in kwargs:
176189
kwargs["size"] = np.array(kwargs["size"])
177-
return kwargs
190+
return args, kwargs
191+
192+
193+
def _logcumsumexp_input_wrangler(
194+
args: list[Any], kwargs: dict[str, Any]
195+
) -> tuple[list[Any], dict[str, Any]]:
196+
kwargs["keepdim"] = args.pop()
197+
return args, kwargs
178198

179199

180200
# Ops to be tested for numerical consistency between onnx and pytorch
@@ -185,16 +205,16 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
185205
| Callable[..., Any]
186206
| tuple[
187207
onnxscript.OnnxFunction | Callable[..., Any],
188-
Callable[[dict[str, Any]], dict[str, Any]],
208+
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]],
189209
],
190210
] = {
191211
"abs": core_ops.aten_abs,
192212
"acos": core_ops.aten_acos,
193213
"acosh": core_ops.aten_acosh,
194214
"add": core_ops.aten_add,
195215
"addmm": core_ops.aten_addmm,
196-
"amax": (core_ops.aten_amax, _amax_amin_kwargs_wrangler),
197-
"amin": (core_ops.aten_amin, _amax_amin_kwargs_wrangler),
216+
"amax": (core_ops.aten_amax, _amax_amin_input_wrangler),
217+
"amin": (core_ops.aten_amin, _amax_amin_input_wrangler),
198218
"arange_start_step": core_ops.aten_arange_start_step,
199219
"arange_start": core_ops.aten_arange_start,
200220
"arange": core_ops.aten_arange,
@@ -219,11 +239,20 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
219239
"expand": core_ops.aten_expand,
220240
"erf": core_ops.aten_erf,
221241
"fmod": core_ops.aten_fmod,
222-
# TODO(justinchuby): Test aten::full
242+
"full": (core_ops.aten_full, _full_input_wrangler),
223243
"full_like": core_ops.aten_full_like,
224244
"gt": core_ops.aten_gt,
225245
"index_select": core_ops.aten_index_select,
226246
"isinf": core_ops.aten_isinf,
247+
"log": core_ops.aten_log,
248+
"log10": core_ops.aten_log10,
249+
"log1p": core_ops.aten_log1p,
250+
"log2": core_ops.aten_log2,
251+
"logaddexp": core_ops.aten_logaddexp,
252+
"logaddexp2": core_ops.aten_logaddexp2,
253+
"logcumsumexp": core_ops.aten_logcumsumexp,
254+
"logdet": core_ops.aten_logdet,
255+
"logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
227256
"lt": core_ops.aten_lt,
228257
"matmul": core_ops.aten_matmul,
229258
"mm": core_ops.aten_mm,
@@ -237,12 +266,13 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
237266
"nn.functional.elu": nn_ops.aten_elu,
238267
"nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
239268
"nn.functional.linear": nn_ops.aten_linear,
269+
"nn.functional.logsigmoid": nn_ops.aten_log_sigmoid,
240270
"nn.functional.relu": nn_ops.aten_relu,
241271
"nn.functional.relu6": nn_ops.aten_relu6,
242272
"nn.functional.selu": core_ops.aten_selu,
243273
"nn.functional.upsample_nearest2d": (
244274
nn_ops.aten_upsample_nearest2d,
245-
_upsample_kwargs_wrangler,
275+
_upsample_input_wrangler,
246276
),
247277
"nonzero": core_ops.aten_nonzero,
248278
"ones_like": core_ops.aten_ones_like,
@@ -267,6 +297,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
267297
"unsqueeze": core_ops.aten_unsqueeze,
268298
"view": core_ops.aten_view,
269299
"where": core_ops.aten_where,
300+
"xlogy": special_ops.aten_special_xlogy,
270301
"zeros": core_ops.aten_zeros,
271302
"zeros_like": core_ops.aten_zeros_like,
272303
}
@@ -276,7 +307,9 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
276307
EXPECTED_SKIPS_OR_FAILS = (
277308
xfail("amax", reason="ONNX Runtime 1.13 does not support ReduceMax-18"),
278309
xfail("amin", reason="ONNX Runtime 1.13 does not support ReduceMin-18"),
279-
skip("clamp", reason="Enable when onnxscript supports optional inputs"),
310+
skip("clamp", reason="enable when onnxscript supports optional inputs"),
311+
xfail("logcumsumexp", reason="naive implementation not numerically stable"),
312+
xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
280313
xfail(
281314
"nn.functional.linear",
282315
reason="ONNX Runtime thinks the graph is invalid",
@@ -358,23 +391,25 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
358391

359392
duplicate_opinfo(
360393
OPS_DB,
361-
"nn.functional.upsample_nearest",
394+
"arange",
362395
(
363-
"nn.functional.upsample_nearest1d",
364-
"nn.functional.upsample_nearest2d",
365-
"nn.functional.upsample_nearest3d",
396+
"arange_start",
397+
"arange_start_step",
366398
),
367399
)
368400

369401
duplicate_opinfo(
370402
OPS_DB,
371-
"arange",
403+
"nn.functional.upsample_nearest",
372404
(
373-
"arange_start",
374-
"arange_start_step",
405+
"nn.functional.upsample_nearest1d",
406+
"nn.functional.upsample_nearest2d",
407+
"nn.functional.upsample_nearest3d",
375408
),
376409
)
377410

411+
duplicate_opinfo(OPS_DB, "new_full", ("full",))
412+
378413

379414
# END OF SECTION TO MODIFY #####################################################
380415

@@ -477,13 +512,13 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
477512
)
478513

479514
onnx_function_and_wrangler = OPINFO_FUNCTION_MAPPING[op.name]
480-
kwarg_wrangler = None
515+
input_wrangler = None
481516
if isinstance(onnx_function_and_wrangler, tuple):
482-
# Obtain the kwarg_wrangler that manipulates the OpInfo inputs
517+
# Obtain the input_wrangler that manipulates the OpInfo inputs
483518
# to match the aten operator signature
484519
# An example is nn.functional.upsample_nearest2d, which has a different signature
485520
# than the aten operator upsample_nearest2d
486-
onnx_function, kwarg_wrangler = onnx_function_and_wrangler
521+
onnx_function, input_wrangler = onnx_function_and_wrangler
487522
else:
488523
assert callable(onnx_function_and_wrangler)
489524
onnx_function = onnx_function_and_wrangler
@@ -503,8 +538,8 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
503538
continue
504539
input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
505540
kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
506-
if kwarg_wrangler:
507-
kwargs_onnx = kwarg_wrangler(kwargs_onnx)
541+
if input_wrangler:
542+
input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
508543
torch_output = op(*inputs, **cpu_sample.kwargs)
509544
function_output = onnx_function(*input_onnx, **kwargs_onnx)
510545

@@ -524,7 +559,9 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
524559
# Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
525560
torch.testing.assert_close(
526561
torch.tensor(function_output),
527-
torch.tensor(torch_output),
562+
torch_output
563+
if isinstance(torch_output, torch.Tensor)
564+
else torch.tensor(torch_output),
528565
rtol=rtol,
529566
atol=atol,
530567
)

0 commit comments

Comments
 (0)