Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def separate_input_attributes_from_arguments(
else:
continue
elif param.required:
raise TypeError(f"Required input '{param}' was not provided")
# raise TypeError(f"Required input '{param}' was not provided")
print(f"Required input '{param}' was not provided")

return onnx_inputs, onnx_attributes
32 changes: 27 additions & 5 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,30 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType:
raise NotImplementedError()


@torch_op("aten::all")
def aten_all(
    self: TTensor, dim: Optional[int] = None, keepdim: Optional[bool] = False
) -> TTensor:
    """all(Tensor self) -> Tensor

    Returns True iff every element of ``self`` is truthy, optionally reduced
    over ``dim`` only (keeping the reduced axis when ``keepdim`` is True).

    Args:
        self: Input tensor of any dtype; values are interpreted as booleans.
        dim: Optional dimension to reduce over. When absent, all dimensions
            are reduced.
        keepdim: Whether the reduced dimension is retained with size 1.

    Returns:
        A BOOL tensor holding the reduction result.
    """

    # Rank-0 inputs are reshaped to rank 1 so the reduction has an axis to
    # work over; the result is squeezed back to a scalar at the end.
    self_rank = op.Size(op.Shape(self))
    if self_rank == 0:
        self = op.Reshape(self, op.Constant(value_ints=[-1]))

    # Cast through BOOL first so any nonzero value normalizes to 1. "all" is
    # then equivalent to taking the minimum of the resulting 0/1 tensor:
    # the minimum is 1 exactly when every element was truthy.
    self_bool = op.Cast(self, to=BOOL.dtype)
    self_int = op.Cast(self_bool, to=INT64.dtype)

    if op.OptionalHasElement(dim):
        # Reduce only over the requested dimension.
        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
        result_int = op.ReduceMin(self_int, dims, keepdims=keepdim)
    else:
        # No dim provided: reduce over all dimensions.
        result_int = op.ReduceMin(self_int, keepdims=keepdim)
    result = op.Cast(result_int, to=BOOL.dtype)

    # Undo the rank-0 workaround so scalars come back as scalars.
    if self_rank == 0:
        result = op.Squeeze(result)

    return result


def aten_allclose(
Expand Down Expand Up @@ -2863,10 +2883,11 @@ def aten_isclose(
raise NotImplementedError()


@torch_op("aten::isfinite")
def aten_isfinite(self: TensorType) -> TensorType:
    """isfinite(Tensor self) -> Tensor

    Returns an element-wise BOOL tensor that is True where ``self`` is finite.

    A value is finite iff it is neither +/-infinity nor NaN. ``op.IsInf``
    alone does not flag NaN (IsInf(NaN) is False), so NaN must be excluded
    explicitly to match ``torch.isfinite`` semantics.
    """

    return op.Not(op.Or(op.IsInf(self), op.IsNaN(self)))


@torch_op("aten::isinf")
Expand Down Expand Up @@ -5131,10 +5152,11 @@ def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> Tensor
raise NotImplementedError()


@torch_op("aten::split_with_sizes")
def aten_split_with_sizes(self: TTensor, split_sizes: INT64, dim: int = 0) -> TTensor:
    """split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]

    Splits ``self`` along ``dim`` into chunks whose lengths are given by
    ``split_sizes``, mirroring ``torch.split_with_sizes``.

    Args:
        self: Tensor to split.
        split_sizes: 1-D tensor of chunk lengths along ``dim``; presumably
            these sum to the size of that dimension — TODO confirm caller
            contract.
        dim: Dimension to split along (default 0).

    Returns:
        A sequence of tensors produced by ONNX SplitToSequence.
    """

    return op.SplitToSequence(self, split_sizes, axis=dim)


def aten_split_with_sizes_copy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def _where_input_wrangler(
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]],
],
] = {
"all": core_ops.aten_all,
"abs": core_ops.aten_abs,
"acos": core_ops.aten_acos,
"acosh": core_ops.aten_acosh,
Expand Down Expand Up @@ -323,6 +324,7 @@ def _where_input_wrangler(
"full_like": core_ops.aten_full_like,
"ge": core_ops.aten_ge,
"gt": core_ops.aten_gt,
"isfinite": core_ops.aten_isfinite,
"isinf": core_ops.aten_isinf,
"log": core_ops.aten_log,
"le": core_ops.aten_le,
Expand Down Expand Up @@ -380,6 +382,7 @@ def _where_input_wrangler(
"sin": core_ops.aten_sin,
"sinh": core_ops.aten_sinh,
"softmax": special_ops.aten_special_softmax,
"split_with_sizes": core_ops.aten_split_with_sizes,
"split": core_ops.aten_split,
"sqrt": core_ops.aten_sqrt,
"stack": core_ops.aten_stack,
Expand Down Expand Up @@ -759,6 +762,9 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
inputs=repr(inputs),
kwargs=repr(cpu_sample.kwargs),
):
if i == 9:
print(i)

skip_reason = _should_skip_test_sample(op.name, cpu_sample)
if skip_reason is not None:
# Cannot use self.skip because pytest would skip the entire test
Expand Down