diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index ad4aaea278..d3fce174e1 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -1398,10 +1398,14 @@ def aten_convolution_overrideable(
     raise NotImplementedError()
 
 
-def aten_copy(self: TensorType, src: TensorType, non_blocking: bool = False) -> TensorType:
+@torch_op("aten::copy")
+def aten_copy(
+    self: TTensor, src: TTensor, non_blocking: bool = False  # pylint: disable=unused-argument
+) -> TTensor:
     # copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
 
-    raise NotImplementedError()
+    self = op.Identity(src)
+    return self
 
 
 def aten_copysign(self: TensorType, other: TensorType) -> TensorType:
@@ -1948,10 +1952,17 @@ def aten_empty_quantized(
     raise NotImplementedError()
 
 
-def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
+@torch_op("aten::empty_strided")
+def aten_empty_strided(
+    size: INT64, stride: INT64  # pylint: disable=unused-argument
+) -> TTensor:  # type: ignore[type-var]
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-    raise NotImplementedError()
+    # use zeros to simulate empty()
+    size = op.Cast(size, to=INT64.dtype)
+    zero = op.Constant(value_float=0.0)
+
+    return op.Expand(zero, size)
 
 
 @torch_op("aten::eq")
@@ -2172,10 +2183,15 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
     raise NotImplementedError()
 
 
-def aten_fill(self: TensorType, value: TensorType) -> TensorType:
+@torch_op("aten::fill")
+def aten_fill(self: TTensor, value: TTensor) -> TTensor:
     # fill.Tensor(Tensor self, Tensor value) -> Tensor
 
-    raise NotImplementedError()
+    # after fill, the result should keep the original type of self
+    shape = op.Shape(self)
+    expanded = op.Expand(value, shape)
+    result = op.CastLike(expanded, self)
+    return result
 
 
 def aten_fix(self: TensorType) -> TensorType:
@@ -3799,12 +3815,14 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::native_dropout")
 def aten_native_dropout(
-    input: TensorType, p: float, train: Optional[bool]
-) -> tuple[TensorType, TensorType]:
+    input: TFloatOrBFloat16, p: float, train: bool = True
+) -> Tuple[TFloatOrBFloat16, BOOL]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
 
-    raise NotImplementedError()
+    result, mask = op.Dropout(input, p, train)
+    return result, mask
 
 
 def aten_native_dropout_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index eb96300d05..396e45b4e0 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -264,18 +264,21 @@ def _where_input_wrangler(
     "clamp_max": core_ops.aten_clamp_max,
     "clamp_min": core_ops.aten_clamp_min,
     "clone": core_ops.aten_clone,
+    # "copy": core_ops.aten_copy,  # copy is not in OPS_DB
     "cos": core_ops.aten_cos,
     "cosh": core_ops.aten_cosh,
     # "detach": core_ops.aten_detach,  # detach is not in OP-TEST-DB
     "div": core_ops.aten_div,
     "dot": core_ops.aten_dot,
     "empty": core_ops.aten_empty,
+    # "empty_strided": core_ops.aten_empty_strided,  # empty_strided is not in OPS_DB
     "eq": core_ops.aten_eq,
     "equal": core_ops.aten_equal,
     "exp": core_ops.aten_exp,
     "exp2": core_ops.aten_exp2,
     "expand": core_ops.aten_expand,
     "erf": core_ops.aten_erf,
+    "fill": core_ops.aten_fill,
     "fmod": core_ops.aten_fmod,
     "full": (core_ops.aten_full, _full_input_wrangler),
     "full_like": core_ops.aten_full_like,
@@ -300,6 +303,7 @@ def _where_input_wrangler(
     "minimum": core_ops.aten_minimum,
     "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
+    # "native_dropout": core_ops.aten_native_dropout,  # native_dropout is not in OPS_DB
     "ne": core_ops.aten_ne,
     "neg": core_ops.aten_neg,
     "new_empty": core_ops.aten_new_empty,
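
A quick eager-mode sanity check for the new aten_fill implementation (a minimal sketch, not part of the PR: it assumes onnxscript's eager evaluation of @torch_op functions with numpy inputs, which is how ops_correctness_test.py exercises registered ops; the tensor shapes and values are illustrative):

import numpy as np

from onnxscript.function_libs.torch_aten.ops import core as core_ops

# `self` supplies only the target shape and dtype; `value` is broadcast over it.
self_tensor = np.zeros((2, 3), dtype=np.float32)
value = np.array(7.0, dtype=np.float64)  # deliberately a different dtype

result = core_ops.aten_fill(self_tensor, value)

# Expand broadcasts `value` to Shape(self); CastLike restores float32.
assert result.shape == (2, 3)
assert result.dtype == np.float32
assert (result == 7.0).all()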