
Commit 3c1cf03

a-gardner1 authored and ezyang committed
Add fake impl for aten.unique_dim (#126561)
Follow-up to #113118 and #124306. Developed in coordination with the solution to microsoft/onnxscript#1547.

This PR adds the missing fake tensor implementation for `aten.unique_dim`, thus enabling tracing and compilation of `torch.unique` when `dim` is not None.

Local testing has proceeded with the following simple script (provided that one has checked out the changes in microsoft/onnxscript#1547):

```python
import logging

import numpy as np
import onnx
import onnxruntime as ort
import torch

onnx_program = torch.onnx.dynamo_export(
    lambda x: torch.unique(x, dim=0, return_inverse=True),
    torch.arange(10),
    export_options=torch.onnx.ExportOptions(
        dynamic_shapes=True,
        diagnostic_options=torch.onnx.DiagnosticOptions(
            verbosity_level=logging.DEBUG)))
onnx_program.save("torch_unique.onnx")
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
onnx_outputs = onnx_program(*onnx_inputs)
loaded_onnx_program = onnx.load("torch_unique.onnx")
onnx.checker.check_model(loaded_onnx_program)
ort_session = ort.InferenceSession("torch_unique.onnx")
inputs = np.random.randint(0, 10, 10)
print(f"Inputs: {inputs}")
outputs = ort_session.run(None, {"l_x_": inputs})
print(f"Outputs: {outputs}")
print("Success")
```

Co-authored-by: Edward Z. Yang <[email protected]>
Pull Request resolved: #126561
Approved by: https://github.com/ezyang
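As a quick sanity check independent of the ONNX path, the same fake impl is what lets `torch.compile` keep `torch.unique(..., dim=...)` in a captured graph. The following is an illustrative sketch, not part of the PR; it assumes Dynamo's `capture_dynamic_output_shape_ops` config flag, which opts in to capturing ops whose output shape depends on tensor data:

```python
import torch

# Opt in to capturing ops with data-dependent output shapes;
# otherwise Dynamo graph-breaks on them.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

compiled = torch.compile(
    lambda x: torch.unique(x, dim=0, return_inverse=True), fullgraph=True
)
values, inverse = compiled(torch.randint(0, 5, (8, 3)))
# values has one row per unique row of the input; inverse maps each
# input row back to its position in values.
print(values.shape, inverse.shape)
```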
Parent: 25447ba

3 files changed: +47 −19 lines

test/test_ops.py (2 additions, 2 deletions)

```diff
@@ -2522,8 +2522,8 @@ def map_to_fake(e):
                     or name in sometimes_dynamic_output_op_test
                 )
                 self.assertTrue(
-                    mode.shape_env is None
-                    or not mode.shape_env.allow_dynamic_output_shape_ops
+                    fake_mode.shape_env is None
+                    or not fake_mode.shape_env.allow_dynamic_output_shape_ops
                     or name not in supported_dynamic_output_op_tests
                 )
             except torch._subclasses.fake_tensor.DataDependentOutputException:
```

test/test_proxy_tensor.py (0 additions, 3 deletions)

```diff
@@ -2003,7 +2003,6 @@ def f(t):
     xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
     xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
-    xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
 
     xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
 
@@ -2034,8 +2033,6 @@ def f(t):
 inplace_symbolic_tensor_failures = {
     # bugs
     xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
-    # decomp not implemented
-    xfail('unique', ''),
 }
 
 out_symbolic_tensor_failures = {
```
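The deleted xfails reflect that `torch.unique` now traces under `make_fx` with symbolic shapes. A minimal sketch of what the updated tests exercise (the function and inputs here are illustrative, not taken from the test suite):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.unique(x, dim=0)

# Previously this raised "couldn't find symbolic meta
# function/decomposition"; with the new fake impl it traces.
gm = make_fx(f, tracing_mode="symbolic")(torch.randint(0, 4, (6, 2)))
print(gm.graph)
```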

torch/_subclasses/fake_impls.py (45 additions, 14 deletions)

```diff
@@ -258,9 +258,8 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
 
 
-@register_op_impl(aten._unique2.default)
-def unique2(
-    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
+def _unique(
+    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
 ):
     if (
         fake_mode.shape_env is None
@@ -269,7 +268,8 @@ def unique2(
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if (nnz := arg.unique_memo) is None:
+    # Do not use a memo for unique_dim
+    if dim is not None or (nnz := arg.unique_memo) is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,
@@ -291,28 +291,59 @@ def unique2(
 
         maxval = sys.maxsize - 1
 
-        if not has_free_symbols(arg.numel()):
-            maxval = int(arg.numel())
+        numel = arg.numel() if dim is None else arg.size(dim)
+        if not has_free_symbols(numel):
+            maxval = int(numel)
 
         _constrain_range_for_size(nnz, max=maxval)
 
-        arg.unique_memo = nnz
+        if dim is None:
+            arg.unique_memo = nnz
 
-    ret = [arg.new_empty((nnz,))]
+    if dim is None:
+        ret = [arg.new_empty((nnz,))]
+    else:
+        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
 
-    if return_inverse:
-        ret.append(torch.empty_like(arg))
+    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
+    if return_inverse or return_if_dim_and_cpu:
+        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
     else:
-        ret.append(arg.new_empty(0))
+        inverse = arg.new_empty(0)
+    ret.append(inverse)
 
-    if return_counts:
-        ret.append(torch.empty_like(arg))
+    if return_counts or return_if_dim_and_cpu:
+        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
     else:
-        ret.append(arg.new_empty(0))
+        counts = arg.new_empty(0)
+    ret.append(counts)
 
     return tuple(ret)
 
 
+@register_op_impl(aten._unique2.default)
+def unique2(
+    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
+):
+    return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
+
+
+@register_op_impl(aten.unique_dim.default)
+def unique_dim(
+    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
+):
+    return _unique(
+        fake_mode,
+        func,
+        arg,
+        # normalize dim to be non-negative
+        dim if dim >= 0 else dim % max(arg.ndim, 1),
+        sorted,
+        return_inverse,
+        return_counts,
+    )
+
+
 @register_op_impl(aten.repeat_interleave.Tensor)
 def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
     if output_size is None:
```
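For intuition about what the new `_unique` helper returns in the `dim is not None` case, here is a hedged sketch (not from the PR) that calls `torch.unique` under `FakeTensorMode` with a `ShapeEnv`, which the impl requires because the output size is data-dependent:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

with FakeTensorMode(shape_env=ShapeEnv()):
    x = torch.empty(8, 3)
    values, inverse = torch.unique(x, dim=0, return_inverse=True)
    # The size along dim is an unbacked SymInt, constrained above by
    # x.size(dim); the inverse tensor has the known size x.shape[dim].
    print(values.shape)   # e.g. torch.Size([u0, 3])
    print(inverse.shape)  # torch.Size([8])
```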
