Skip to content

Commit 58ba823

Browse files
authored
Merge branch 'main' into xiaowu/addOps(0216)
2 parents 6be5eb3 + e66ca88 commit 58ba823

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

onnxscript/function_libs/torch_aten/graph_building.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,16 @@ def __init__(self):
296296
# All the functions used, deduplicated by name
297297
# key: (name, domain)
298298
self._function_store: Dict[Tuple[str, str], onnxscript.OnnxFunction] = {}
299+
self._initializers: Dict[str, torch.Tensor] = {}
299300

300301
@property
301302
def torch_graph(self):
302303
return self._torch_graph
303304

305+
@property
306+
def initializers(self) -> Mapping[str, torch.Tensor]:
307+
return self._initializers
308+
304309
@beartype
305310
def add_input(
306311
self, input_name: str, input_value: Optional[torch.Tensor] = None
@@ -327,6 +332,10 @@ def add_input(
327332
tensor_value = _wrap_torch_value_to_tensor(torch_value)
328333
return tensor_value # type: ignore[return-value]
329334

335+
@beartype
336+
def add_initializer(self, input_name: str, input_value: torch.Tensor) -> None:
337+
self._initializers[input_name] = input_value
338+
330339
@beartype
331340
def register_outputs(
332341
self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]
@@ -446,16 +455,14 @@ def add_function_call(
446455
return result
447456

448457
@beartype
449-
def to_model_proto(
450-
self, initializers: Mapping[str, torch.Tensor], opset_version: Optional[int]
451-
) -> onnx.ModelProto:
458+
def to_model_proto(self, opset_version: int) -> onnx.ModelProto:
452459
(
453460
proto,
454461
_,
455462
_,
456463
_,
457464
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
458-
initializers=initializers,
465+
initializers=self.initializers,
459466
onnx_opset_version=opset_version,
460467
# TODO(justinchuby): Figure out how to get the dynamic axes from the inputs
461468
dynamic_axes={},

onnxscript/function_libs/torch_aten/graph_building_test.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,22 @@
1717
@unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported")
1818
class TestTorchScriptTracingEvaluator(unittest.TestCase):
1919
def setUp(self):
20+
# FIXME: Currently this must match with the import line
21+
# `from onnxscript import opset17 as op`, which restricts opset to be 17 in these
22+
# tests anyways.
2023
self.opset_version = 17
24+
# TODO: Add test for initializer. Currently skipped since to `assert_isomorphic`
25+
# does not check for initializers.
2126
self.onnxscript_graph = graph_building.TorchScriptGraph()
2227
self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph)
2328

24-
def to_model_proto(self):
25-
# TODO(titaiwang): initializer API
26-
return self.onnxscript_graph.to_model_proto(
27-
initializers={}, opset_version=self.opset_version
28-
)
29-
3029
def test_traced_constant_op_is_same_as_compiled_graph(self):
3130
"""Test for op.Constant created in graph builder"""
3231
with evaluator.default_as(self.tracer):
3332
output = op.Constant(value_float=0.5)
3433

3534
self.onnxscript_graph.register_outputs(output)
36-
traced = self.to_model_proto()
35+
traced = self.onnxscript_graph.to_model_proto(self.opset_version)
3736

3837
@onnxscript.script()
3938
def expected_model():
@@ -51,7 +50,7 @@ def test_traced_graph_on_single_node_is_same_as_compiled_graph(self):
5150
output = aten_relu(x)
5251

5352
self.onnxscript_graph.register_outputs(output)
54-
traced = self.to_model_proto()
53+
traced = self.onnxscript_graph.to_model_proto(self.opset_version)
5554

5655
@onnxscript.script(default_opset=op)
5756
def expected_model(x: FLOAT[1, 2, 3]):
@@ -70,7 +69,7 @@ def test_traced_graph_on_single_node_multi_output_is_same_as_compiled_graph(self
7069
output = aten_topk(x, 2)
7170

7271
self.onnxscript_graph.register_outputs(output)
73-
traced = self.to_model_proto()
72+
traced = self.onnxscript_graph.to_model_proto(self.opset_version)
7473

7574
@onnxscript.script(default_opset=op)
7675
def expected_model(x: FLOAT[1, 2, 3]):

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,8 @@ def aten_expand(self: TTensor, size: TInt) -> TTensor:
20392039
"""expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
20402040

20412041
size = op.Cast(size, to=INT64.dtype)
2042+
# To support -1 dim.
2043+
size = op.Abs(size)
20422044
return op.Expand(self, size)
20432045

20442046

0 commit comments

Comments
 (0)