Skip to content

Commit b2c19a7

Browse files
authored
Add BOOL variants to bitwise ops | fix(torchlib) (#907)
In torchbench we see errors like `[ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from bench_dynamo_onnx_model/attention_is_all_you_need_pytorch/model.onnx failed:This is an invalid model. Type Error: Type 'tensor(bool)' of input parameter (unsqueeze_1) of operator (BitwiseAnd) in node (n0__9) is invalid.` This change fixes them by registering the Boolean variants from the exiting logical* implementations. The reason for not creating new bitwise* variants is that the implementation is exactly the same as the logical* ones, which are covered by test cases. The bool bitwise* BOOL variants, if implemented separately, besides completely identical logic, would require test case duplication, which imo is not worth it.
1 parent 4efa764 commit b2c19a7

File tree

1 file changed

+56
-13
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+56
-13
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,9 +1007,17 @@ def aten_binomial(
10071007
raise NotImplementedError()
10081008

10091009

1010-
@torch_op("aten::bitwise_and")
1010+
@torch_op(
1011+
(
1012+
"aten::bitwise_and",
1013+
"aten::bitwise_and.Tensor",
1014+
"aten::bitwise_and.Scalar",
1015+
"aten::bitwise_and.Scalar_Tensor",
1016+
)
1017+
)
10111018
def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
10121019
"""bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
1020+
# logical_and implements the BOOL variant
10131021

10141022
return op.BitwiseAnd(self, other)
10151023

@@ -1024,19 +1032,22 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
10241032
@torch_op("aten::bitwise_not")
10251033
def aten_bitwise_not(self: TInt) -> TInt:
10261034
"""bitwise_not(Tensor self) -> Tensor"""
1035+
# logical_not implements the BOOL variant
10271036

10281037
return op.BitwiseNot(self)
10291038

10301039

1031-
@torch_op("aten::bitwise_not")
1032-
def aten_bitwise_not_bool(self: BOOL) -> BOOL:
1033-
"""bitwise_not(Tensor self) -> Tensor"""
1034-
return op.Not(self)
1035-
1036-
1037-
@torch_op("aten::bitwise_or")
1040+
@torch_op(
1041+
(
1042+
"aten::bitwise_or",
1043+
"aten::bitwise_or.Tensor",
1044+
"aten::bitwise_or.Scalar",
1045+
"aten::bitwise_or.Scalar_Tensor",
1046+
)
1047+
)
10381048
def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
10391049
"""bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
1050+
# logical_or implements the BOOL variant
10401051

10411052
return op.BitwiseOr(self, other)
10421053

@@ -1048,9 +1059,17 @@ def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt:
10481059
return op.BitShift(self, other, direction="RIGHT")
10491060

10501061

1051-
@torch_op("aten::bitwise_xor")
1062+
@torch_op(
1063+
(
1064+
"aten::bitwise_xor",
1065+
"aten::bitwise_xor.Tensor",
1066+
"aten::bitwise_xor.Scalar",
1067+
"aten::bitwise_xor.Scalar_Tensor",
1068+
)
1069+
)
10521070
def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
10531071
"""bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
1072+
# logical_xor implements the BOOL variant
10541073

10551074
return op.BitwiseXor(self, other)
10561075

@@ -3734,28 +3753,52 @@ def aten_logdet(self: TFloat) -> TFloat:
37343753
return op.Log(op.Det(self))
37353754

37363755

3737-
@torch_op("aten::logical_and")
3756+
@torch_op(
3757+
(
3758+
"aten::logical_and",
3759+
"aten::bitwise_and",
3760+
"aten::bitwise_and.Tensor",
3761+
"aten::bitwise_and.Scalar",
3762+
"aten::bitwise_and.Scalar_Tensor",
3763+
)
3764+
)
37383765
def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
37393766
"""logical_and(Tensor self, Tensor other) -> Tensor"""
37403767

37413768
return op.And(self, other)
37423769

37433770

3744-
@torch_op("aten::logical_not")
3771+
@torch_op(("aten::logical_not", "aten::bitwise_not"))
37453772
def aten_logical_not(self: BOOL) -> BOOL:
37463773
"""logical_not(Tensor self) -> Tensor"""
37473774

37483775
return op.Not(self)
37493776

37503777

3751-
@torch_op("aten::logical_or")
3778+
@torch_op(
3779+
(
3780+
"aten::logical_or",
3781+
"aten::bitwise_or",
3782+
"aten::bitwise_or.Tensor",
3783+
"aten::bitwise_or.Scalar",
3784+
"aten::bitwise_or.Scalar_Tensor",
3785+
)
3786+
)
37523787
def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
37533788
"""logical_or(Tensor self, Tensor other) -> Tensor"""
37543789

37553790
return op.Or(self, other)
37563791

37573792

3758-
@torch_op("aten::logical_xor")
3793+
@torch_op(
3794+
(
3795+
"aten::logical_xor",
3796+
"aten::bitwise_xor",
3797+
"aten::bitwise_xor.Tensor",
3798+
"aten::bitwise_xor.Scalar",
3799+
"aten::bitwise_xor.Scalar_Tensor",
3800+
)
3801+
)
37593802
def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
37603803
"""logical_xor(Tensor self, Tensor other) -> Tensor"""
37613804

0 commit comments

Comments
 (0)