Skip to content

Commit 510db96

Browse files
authored
feat(atenlib): op(cross) (#512)
onnx/onnx#2683
1 parent 17915ed commit 510db96

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,10 +1531,29 @@ def aten_cov(
15311531
raise NotImplementedError()
15321532

15331533

1534-
def aten_cross(self: TensorType, other: TensorType, dim: Optional[int] = None) -> TensorType:
1534+
@torch_op("aten::cross")
1535+
def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
15351536
"""cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""
15361537

1537-
raise NotImplementedError()
1538+
zero = op.Constant(value_ints=[0])
1539+
one = op.Constant(value_ints=[1])
1540+
two = op.Constant(value_ints=[2])
1541+
three = op.Constant(value_ints=[3])
1542+
axes = op.Expand(dim, op.Constant(value_ints=[1]))
1543+
1544+
# Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
1545+
a1 = op.Slice(self, zero, one, axes)
1546+
a2 = op.Slice(self, one, two, axes)
1547+
a3 = op.Slice(self, two, three, axes)
1548+
b1 = op.Slice(other, zero, one, axes)
1549+
b2 = op.Slice(other, one, two, axes)
1550+
b3 = op.Slice(other, two, three, axes)
1551+
# Broadcasting is implicitly supported by Mul
1552+
c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
1553+
c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
1554+
c3 = op.Sub(op.Mul(a1, b2), op.Mul(a2, b1))
1555+
1556+
return op.Concat(c1, c2, c3, axis=dim)
15381557

15391558

15401559
def aten_crow_indices(self: TensorType) -> TensorType:
@@ -2009,7 +2028,6 @@ def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
20092028

20102029
@torch_op("aten::empty_like", overload=True)
20112030
def _aten_empty_like_onnx(self: TTensor, zero) -> TTensor:
2012-
20132031
shape = op.Shape(self)
20142032
return op.Expand(zero, shape)
20152033

@@ -4236,7 +4254,6 @@ def aten_ones_like(self: TTensor, dtype: int = -1) -> TTensor:
42364254

42374255
@torch_op("aten::ones_like", overload=True)
42384256
def _aten_ones_like_onnx(self: TTensor, one) -> TTensor:
4239-
42404257
shape = op.Shape(self)
42414258
return op.Expand(one, shape)
42424259

@@ -5790,6 +5807,5 @@ def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor:
57905807

57915808
@torch_op("aten::zeros_like", overload=True)
57925809
def _aten_zeros_like_onnx(self: TTensor, zero) -> TTensor:
5793-
57945810
shape = op.Shape(self)
57955811
return op.Expand(zero, shape)

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def _where_input_wrangler(
305305
# "copy": core_ops.aten_copy, # copy is not in OPS_DB
306306
"cos": core_ops.aten_cos,
307307
"cosh": core_ops.aten_cosh,
308+
"cross": core_ops.aten_cross,
308309
# "detach": core_ops.aten_detach, # detach is not in OP-TEST-DB
309310
"div": core_ops.aten_div,
310311
"dot": core_ops.aten_dot,
@@ -751,7 +752,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
751752
assert callable(onnx_function_and_wrangler)
752753
onnx_function = onnx_function_and_wrangler
753754

754-
for (i, cpu_sample) in enumerate(samples):
755+
for i, cpu_sample in enumerate(samples):
755756
inputs = (cpu_sample.input, *cpu_sample.args)
756757
# Provide the repr to subtest because tensors are not serializable in parallel test runs
757758
with self.subTest(

0 commit comments

Comments
 (0)