@@ -1531,10 +1531,29 @@ def aten_cov(
1531
1531
raise NotImplementedError ()
1532
1532
1533
1533
1534
- def aten_cross (self : TensorType , other : TensorType , dim : Optional [int ] = None ) -> TensorType :
1534
+ @torch_op ("aten::cross" )
1535
+ def aten_cross (self : TTensor , other : TTensor , dim : int = - 1 ) -> TTensor :
1535
1536
"""cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""
1536
1537
1537
- raise NotImplementedError ()
1538
+ zero = op .Constant (value_ints = [0 ])
1539
+ one = op .Constant (value_ints = [1 ])
1540
+ two = op .Constant (value_ints = [2 ])
1541
+ three = op .Constant (value_ints = [3 ])
1542
+ axes = op .Expand (dim , op .Constant (value_ints = [1 ]))
1543
+
1544
+ # Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
1545
+ a1 = op .Slice (self , zero , one , axes )
1546
+ a2 = op .Slice (self , one , two , axes )
1547
+ a3 = op .Slice (self , two , three , axes )
1548
+ b1 = op .Slice (other , zero , one , axes )
1549
+ b2 = op .Slice (other , one , two , axes )
1550
+ b3 = op .Slice (other , two , three , axes )
1551
+ # Broadcasting is implicitly supported by Mul
1552
+ c1 = op .Sub (op .Mul (a2 , b3 ), op .Mul (a3 , b2 ))
1553
+ c2 = op .Sub (op .Mul (a3 , b1 ), op .Mul (a1 , b3 ))
1554
+ c3 = op .Sub (op .Mul (a1 , b2 ), op .Mul (a2 , b1 ))
1555
+
1556
+ return op .Concat (c1 , c2 , c3 , axis = dim )
1538
1557
1539
1558
1540
1559
def aten_crow_indices (self : TensorType ) -> TensorType :
@@ -2009,7 +2028,6 @@ def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
2009
2028
2010
2029
@torch_op ("aten::empty_like" , overload = True )
2011
2030
def _aten_empty_like_onnx (self : TTensor , zero ) -> TTensor :
2012
-
2013
2031
shape = op .Shape (self )
2014
2032
return op .Expand (zero , shape )
2015
2033
@@ -4236,7 +4254,6 @@ def aten_ones_like(self: TTensor, dtype: int = -1) -> TTensor:
4236
4254
4237
4255
@torch_op ("aten::ones_like" , overload = True )
4238
4256
def _aten_ones_like_onnx (self : TTensor , one ) -> TTensor :
4239
-
4240
4257
shape = op .Shape (self )
4241
4258
return op .Expand (one , shape )
4242
4259
@@ -5790,6 +5807,5 @@ def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor:
5790
5807
5791
5808
@torch_op ("aten::zeros_like" , overload = True )
5792
5809
def _aten_zeros_like_onnx (self : TTensor , zero ) -> TTensor :
5793
-
5794
5810
shape = op .Shape (self )
5795
5811
return op .Expand (zero , shape )
0 commit comments