diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 0cbbf4507f..8bebcecbeb 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4297,16 +4297,40 @@ def aten_new_full( return op.Expand(fill_value, size) -def aten_new_ones(self: TensorType, size: INT64) -> TensorType: +@torch_op("aten::new_ones") +def aten_new_ones(self: TReal, size: INT64) -> TReal: # pylint: disable=unused-argument """new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - raise NotImplementedError() + one = op.Constant(value_float=1.0) + return op.Expand(one, size) + + +@torch_op("aten::new_ones", overload=True) +def aten_new_ones_dtype( + self: TReal, size: INT64, dtype: int # pylint: disable=unused-argument +) -> TReal: + + one = op.Constant(value_float=1.0) + one = op.Cast(one, to=dtype) + return op.Expand(one, size) -def aten_new_zeros(self: TensorType, size: INT64) -> TensorType: +@torch_op("aten::new_zeros") +def aten_new_zeros(self: TReal, size: INT64) -> TReal: # pylint: disable=unused-argument """new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" - raise NotImplementedError() + zero = op.Constant(value_float=0.0) + return op.Expand(zero, size) + + +@torch_op("aten::new_zeros", overload=True) +def aten_new_zeros_dtype( + self: TReal, size: INT64, dtype: int # pylint: disable=unused-argument +) -> TReal: + + zero = op.Constant(value_float=0.0) + zero = op.Cast(zero, to=dtype) + return op.Expand(zero, size) def aten_nextafter(self: TensorType, other: TensorType) -> TensorType: diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py index e5250a571c..bc3d5267e5 100644 --- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py @@ -431,6 +431,10 @@ def _where_input_wrangler( "ne": core_ops.aten_ne, "neg": core_ops.aten_neg, "new_full": core_ops.aten_new_full, + "new_ones_dtype": core_ops.aten_new_ones_dtype, + "new_ones": core_ops.aten_new_ones, + "new_zeros_dtype": core_ops.aten_new_zeros_dtype, + "new_zeros": core_ops.aten_new_zeros, "nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d, "nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d, "nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d, @@ -589,6 +593,26 @@ def _where_input_wrangler( reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_new_full: failed validating the check: !(it.GetName().empty())'", test_class_name="TestOutputConsistencyFullGraph", ), + xfail( + "new_ones", + reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_new_ones: failed validating the check: !(it.GetName().empty())'", + test_class_name="TestOutputConsistencyFullGraph", + ), + xfail( + "new_ones_dtype", + reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_new_ones: failed validating the check: !(it.GetName().empty())'", + test_class_name="TestOutputConsistencyFullGraph", + ), + xfail( + "new_zeros", 
+ reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_new_zeros: failed validating the check: !(it.GetName().empty())'", + test_class_name="TestOutputConsistencyFullGraph", + ), + xfail( + "new_zeros_dtype", + reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_new_zeros: failed validating the check: !(it.GetName().empty())'", + test_class_name="TestOutputConsistencyFullGraph", + ), xfail( "nn.functional.adaptive_avg_pool1d", reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_adaptive_avg_pool1d: failed validating the check: !(it.GetName().empty())'", test_class_name="TestOutputConsistencyFullGraph", ), @@ -705,6 +729,26 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), + skip( + "new_ones", + matcher=lambda sample: sample.kwargs.get("dtype") is not None, + reason="this Aten overload only supports dtype=None; dtype cases are covered by new_ones_dtype", + ), + skip( + "new_ones_dtype", + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="this Aten overload requires an explicit dtype argument", + ), + skip( + "new_zeros", + matcher=lambda sample: sample.kwargs.get("dtype") is not None, + reason="this Aten overload only supports dtype=None; dtype cases are covered by new_zeros_dtype", + ), + skip( + "new_zeros_dtype", + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="this Aten overload requires an explicit dtype argument", + ), skip( "nonzero", matcher=lambda sample: sample.kwargs.get("as_tuple") is not None, @@ -807,6 +851,10 @@ def _where_input_wrangler( duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) +duplicate_opinfo(OPS_DB, "new_ones", ("new_ones_dtype",)) + +duplicate_opinfo(OPS_DB, "new_zeros", ("new_zeros_dtype",)) + duplicate_opinfo(OPS_DB, "nn.functional.nll_loss", ("nn.functional.nll_loss_weight",)) duplicate_opinfo(