Skip to content

Commit 5575c01

Browse files
authored
[torchlib] Register aten::__lshift__ and __rshift__ (#2102)
Tested with ```py import math import torch class Gray(torch.nn.Module): nbits: int = 32 def forward(self, gray: torch.Tensor): shifts = [(0x1 << i) for i in range((math.ceil(math.log(self.nbits, 2)) - 1), -1, -1)] for shift in shifts: gray ^= gray >> shift return gray onnx_program = torch.onnx.export( Gray(), # model to export (torch.randint(0, 100, [100], dtype=torch.long)), # inputs of the model, dynamo=True, # True or False to select the exporter to use, ) print(onnx_program) ``` Fixes pytorch/pytorch#149083
1 parent 1da3b9c commit 5575c01

File tree

1 file changed

+8
-0
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+8
-0
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,7 @@ def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
12431243
"aten::bitwise_left_shift.Tensor_Scalar",
12441244
"aten::bitwise_left_shift.Scalar_Tensor",
12451245
"_operator::__lshift__",
1246+
"aten::__lshift__.Scalar",
12461247
),
12471248
trace_only=True,
12481249
)
@@ -1263,6 +1264,7 @@ def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
12631264
"aten::bitwise_left_shift.Tensor_Scalar",
12641265
"aten::bitwise_left_shift.Scalar_Tensor",
12651266
"_operator::__lshift__",
1267+
"aten::__lshift__.Scalar",
12661268
),
12671269
trace_only=True,
12681270
)
@@ -1283,6 +1285,7 @@ def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
12831285
"aten::bitwise_left_shift.Tensor_Scalar",
12841286
"aten::bitwise_left_shift.Scalar_Tensor",
12851287
"_operator::__lshift__",
1288+
"aten::__lshift__.Scalar",
12861289
),
12871290
trace_only=True,
12881291
)
@@ -1303,6 +1306,7 @@ def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
13031306
"aten::bitwise_left_shift.Tensor_Scalar",
13041307
"aten::bitwise_left_shift.Scalar_Tensor",
13051308
"_operator::__lshift__",
1309+
"aten::__lshift__.Scalar",
13061310
),
13071311
trace_only=True,
13081312
)
@@ -1347,6 +1351,7 @@ def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
13471351
"aten::bitwise_right_shift.Tensor_Scalar",
13481352
"aten::bitwise_right_shift.Scalar_Tensor",
13491353
"_operator::__rshift__",
1354+
"aten::__rshift__.Scalar",
13501355
)
13511356
)
13521357
def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
@@ -1377,6 +1382,7 @@ def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
13771382
"aten::bitwise_right_shift.Tensor_Scalar",
13781383
"aten::bitwise_right_shift.Scalar_Tensor",
13791384
"_operator::__rshift__",
1385+
"aten::__rshift__.Scalar",
13801386
)
13811387
)
13821388
def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
@@ -1407,6 +1413,7 @@ def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
14071413
"aten::bitwise_right_shift.Tensor_Scalar",
14081414
"aten::bitwise_right_shift.Scalar_Tensor",
14091415
"_operator::__rshift__",
1416+
"aten::__rshift__.Scalar",
14101417
)
14111418
)
14121419
def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64:
@@ -1440,6 +1447,7 @@ def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64:
14401447
"aten::bitwise_right_shift.Tensor_Scalar",
14411448
"aten::bitwise_right_shift.Scalar_Tensor",
14421449
"_operator::__rshift__",
1450+
"aten::__rshift__.Scalar",
14431451
)
14441452
)
14451453
def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:

0 commit comments

Comments
 (0)