
Merge 'initializers' into 'TorchScriptGraph' #480


Merged: 5 commits, Mar 1, 2023
Changes from 3 commits
15 changes: 11 additions & 4 deletions onnxscript/function_libs/torch_aten/graph_building.py
@@ -296,11 +296,16 @@ def __init__(self):
# All the functions used, deduplicated by name
# key: (name, domain)
self._function_store: Dict[Tuple[str, str], onnxscript.OnnxFunction] = {}
self._initializers: Dict[str, torch.Tensor] = {}

@property
def torch_graph(self):
return self._torch_graph

@property
def initializers(self) -> Mapping[str, torch.Tensor]:
Collaborator: Would a brief docstring for this property help?

return self._initializers

@beartype
def add_input(
self, input_name: str, input_value: Optional[torch.Tensor] = None
@@ -327,6 +332,10 @@ def add_input(
tensor_value = _wrap_torch_value_to_tensor(torch_value)
return tensor_value # type: ignore[return-value]

@beartype
def add_initializer(self, input_name: str, input_value: torch.Tensor) -> None:
self._initializers[input_name] = input_value

@beartype
def register_outputs(
self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]
@@ -446,16 +455,14 @@ def add_function_call(
return result

@beartype
def to_model_proto(
self, initializers: Mapping[str, torch.Tensor], opset_version: Optional[int]
) -> onnx.ModelProto:
def to_model_proto(self, opset_version: int) -> onnx.ModelProto:
(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
initializers=initializers,
initializers=self.initializers,
onnx_opset_version=opset_version,
# TODO(justinchuby): Figure out how to get the dynamic axes from the inputs
dynamic_axes={},
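
Taken together, this change moves initializer ownership into `TorchScriptGraph` itself: callers register tensors with `add_initializer`, and `to_model_proto` reads them back from the `initializers` property instead of taking an `initializers` mapping argument. A minimal sketch of the new call pattern (the import path follows the file above; the tensor name is illustrative, and a real graph would also register inputs, nodes, and outputs before exporting):

    import torch

    from onnxscript.function_libs.torch_aten import graph_building

    graph = graph_building.TorchScriptGraph()

    # Initializers now live on the graph; register each one as it is encountered.
    weight = torch.rand(3, 4)
    graph.add_initializer("linear.weight", weight)

    # ... build the rest of the graph (inputs, nodes, outputs) ...

    # Export no longer takes an `initializers=` argument; only the opset version is
    # passed, and the tensors registered above are picked up via `self.initializers`.
    model_proto = graph.to_model_proto(17)
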
17 changes: 8 additions & 9 deletions onnxscript/function_libs/torch_aten/graph_building_test.py
@@ -17,23 +17,22 @@
@unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported")
class TestTorchScriptTracingEvaluator(unittest.TestCase):
def setUp(self):
# FIXME: Currently this must match with the import line
# `from onnxscript import opset17 as op`, which restricts opset to be 17 in these
Collaborator: The model opset can actually be different from the local functions' opsets, I think. Do you know if that's true?

Contributor Author: Yes, it can be different, subject to a constraint:
https://github.com/onnx/onnx/blob/4b2d50334914621835cc1e8dadd4fe82b6b9876c/onnx/onnx.in.proto#L824-L828

  // The operator sets imported by FunctionProto should be compatible with the ones
  // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
  // and ModelProto then versions for the operator set may be different but,
  // the operator schema returned for op_type, domain, version combination
  // for both the versions should be same.
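
To make the quoted constraint concrete, here is a small standalone sketch (illustrative only, not part of this PR) in which the ModelProto imports opset 18 while a model-local FunctionProto imports opset 17 of the same default domain; this is allowed as long as every op the function uses resolves to the same schema under both versions:

    import onnx
    from onnx import helper

    # A trivial local function; Identity has the same schema in opsets 17 and 18.
    identity_fn = helper.make_function(
        domain="custom",
        fname="MyIdentity",
        inputs=["x"],
        outputs=["y"],
        nodes=[helper.make_node("Identity", ["x"], ["y"])],
        opset_imports=[helper.make_opsetid("", 17)],  # the function imports opset 17
    )

    x = helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1])
    y = helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1])
    graph = helper.make_graph(
        [helper.make_node("MyIdentity", ["x"], ["y"], domain="custom")], "g", [x], [y]
    )

    # The model imports opset 18 plus the custom domain hosting the local function.
    model = helper.make_model(
        graph,
        opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("custom", 1)],
    )
    model.functions.append(identity_fn)

    print([(o.domain, o.version) for o in model.opset_import])        # [('', 18), ('custom', 1)]
    print([(o.domain, o.version) for o in identity_fn.opset_import])  # [('', 17)]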

# tests anyways.
self.opset_version = 17
# TODO: Add test for initializer. Currently skipped since `assert_isomorphic`
# does not check for initializers.
self.onnxscript_graph = graph_building.TorchScriptGraph()
self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph)

def to_model_proto(self):
# TODO(titaiwang): initializer API
return self.onnxscript_graph.to_model_proto(
initializers={}, opset_version=self.opset_version
)

def test_traced_constant_op_is_same_as_compiled_graph(self):
"""Test for op.Constant created in graph builder"""
with evaluator.default_as(self.tracer):
output = op.Constant(value_float=0.5)

self.onnxscript_graph.register_outputs(output)
traced = self.to_model_proto()
traced = self.onnxscript_graph.to_model_proto(self.opset_version)

@onnxscript.script()
def expected_model():
@@ -51,7 +50,7 @@ def test_traced_graph_on_single_node_is_same_as_compiled_graph(self):
output = aten_relu(x)

self.onnxscript_graph.register_outputs(output)
traced = self.to_model_proto()
traced = self.onnxscript_graph.to_model_proto(self.opset_version)

@onnxscript.script(default_opset=op)
def expected_model(x: FLOAT[1, 2, 3]):
@@ -70,7 +69,7 @@ def test_traced_graph_on_single_node_multi_output_is_same_as_compiled_graph(self):
output = aten_topk(x, 2)

self.onnxscript_graph.register_outputs(output)
traced = self.to_model_proto()
traced = self.onnxscript_graph.to_model_proto(self.opset_version)

@onnxscript.script(default_opset=op)
def expected_model(x: FLOAT[1, 2, 3]):
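
On the initializer side, the setUp TODO above notes that test coverage is still missing because `assert_isomorphic` does not compare initializers. One possible follow-up, sketched here with hypothetical names and not part of this PR, is to assert directly on the exported GraphProto instead (a real test would first trace a small graph and register its outputs, as in the tests above, so that export succeeds):

    import unittest

    import torch

    from onnxscript.function_libs.torch_aten import graph_building


    class TestInitializerExport(unittest.TestCase):  # hypothetical test, not in this PR
        def test_registered_initializer_appears_in_model_proto(self):
            graph = graph_building.TorchScriptGraph()
            graph.add_initializer("weight", torch.ones(2, 3))
            # ... trace a small graph and register its outputs here ...
            model = graph.to_model_proto(17)
            # GraphProto stores initializers as repeated TensorProto entries.
            initializer_names = {init.name for init in model.graph.initializer}
            self.assertIn("weight", initializer_names)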