Skip to content

Commit 697c6ac

Browse files
authored
Add Op(multinomial) | torchlib(feat) (#1032)
~~1. xfail reason: ONNX spec expects the first dim is batch size, and the whole input should be 2D.~~ ~~2. extra opinfo reason: We would like to have a pure 2D datasets to make sure the op is functional.~~ ATen supports 1D and 2D, but ONNX only supports 2D, so Unsqueeze would be used when it's 1D input.
1 parent f270dc1 commit 697c6ac

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4820,15 +4820,24 @@ def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
48204820
return op.And(self, other)
48214821

48224822

4823+
@torch_op("aten::multinomial")
48234824
def aten_multinomial(
4824-
self: TensorType,
4825+
self: TFloat,
48254826
num_samples: int,
4826-
replacement: bool = False,
4827-
generator: Optional[str] = None,
4828-
) -> TensorType:
4827+
replacement: bool = False, # pylint: disable=unused-argument
4828+
) -> TInt:
48294829
"""multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor"""
4830-
4831-
raise NotImplementedError()
4830+
# ONNX Multinomial doesn't support 1D input
4831+
if op.Size(op.Shape(self)) == 1:
4832+
unsqueezed_input = op.Unsqueeze(self, axes=0)
4833+
else:
4834+
unsqueezed_input = self
4835+
# ONNX multinomial expects log probability
4836+
log_input = op.Log(unsqueezed_input)
4837+
result = op.Multinomial(log_input, dtype=INT64.dtype, sample_size=num_samples)
4838+
if op.Size(op.Shape(self)) == 1:
4839+
result = op.Squeeze(result)
4840+
return result
48324841

48334842

48344843
def aten_multiply(self: TensorType, other: TensorType) -> TensorType:

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,7 @@ def _where_input_wrangler(
15451545
matcher=lambda sample: len(sample.args) > 0,
15461546
reason="this ATen overload only supports one tensor as input by design",
15471547
),
1548+
TorchLibOpInfo("multinomial", core_ops.aten_multinomial, nondeterministic=True),
15481549
TorchLibOpInfo(
15491550
# Custom from extra_opinfo
15501551
"ops.aten.max_pool1d",

0 commit comments

Comments
 (0)