Merge 'initializers' into 'TorchScriptGraph' #480
Changes from 3 commits
@@ -17,23 +17,22 @@
@unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported")
class TestTorchScriptTracingEvaluator(unittest.TestCase):
    def setUp(self):
        # FIXME: Currently this must match with the import line
        # `from onnxscript import opset17 as op`, which restricts opset to be 17 in these
Review comment: The model opset can actually be different from the local functions' opsets I think. Do you know if that's true?

Reply: Yea it could be different under constraint.

(A small standalone ONNX sketch illustrating this point follows the hunk below.)
        # tests anyways.
        self.opset_version = 17
        # TODO: Add test for initializer. Currently skipped since `assert_isomorphic`
        # does not check for initializers.
        self.onnxscript_graph = graph_building.TorchScriptGraph()
        self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph)

    def to_model_proto(self):
        # TODO(titaiwang): initializer API
        return self.onnxscript_graph.to_model_proto(
            initializers={}, opset_version=self.opset_version
        )

    def test_traced_constant_op_is_same_as_compiled_graph(self):
        """Test for op.Constant created in graph builder"""
        with evaluator.default_as(self.tracer):
            output = op.Constant(value_float=0.5)

        self.onnxscript_graph.register_outputs(output)
        traced = self.to_model_proto()
        traced = self.onnxscript_graph.to_model_proto(self.opset_version)

        @onnxscript.script()
        def expected_model():
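On the opset question in the review thread above: in ONNX, the model and each model-local function carry their own `opset_import` list, so the graph-level opset and a function's opset can indeed differ, subject to compatibility constraints. A minimal standalone sketch, not taken from this PR; it assumes a recent `onnx` package and uses made-up names such as `MyRelu`:

from onnx import TensorProto, helper

# A model-local function that imports opset 17 for the default ONNX domain ("").
my_relu = helper.make_function(
    domain="custom",
    fname="MyRelu",
    inputs=["x"],
    outputs=["y"],
    nodes=[helper.make_node("Relu", ["x"], ["y"])],
    opset_imports=[helper.make_opsetid("", 17)],
    attributes=[],
)

# A top-level graph that calls the function; the model itself imports opset 18.
graph = helper.make_graph(
    nodes=[helper.make_node("MyRelu", ["x"], ["y"], domain="custom")],
    name="main",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 3])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("custom", 1)],
)
model.functions.extend([my_relu])

# The two opset_import lists live on different protos and need not be equal.
print([(o.domain, o.version) for o in model.opset_import])                # [('', 18), ('custom', 1)]
print([(o.domain, o.version) for o in model.functions[0].opset_import])   # [('', 17)]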
@@ -51,7 +50,7 @@ def test_traced_graph_on_single_node_is_same_as_compiled_graph(self):
            output = aten_relu(x)

        self.onnxscript_graph.register_outputs(output)
        traced = self.to_model_proto()
        traced = self.onnxscript_graph.to_model_proto(self.opset_version)

        @onnxscript.script(default_opset=op)
        def expected_model(x: FLOAT[1, 2, 3]):
@@ -70,7 +69,7 @@ def test_traced_graph_on_single_node_multi_output_is_same_as_compiled_graph(self
            output = aten_topk(x, 2)

        self.onnxscript_graph.register_outputs(output)
        traced = self.to_model_proto()
        traced = self.onnxscript_graph.to_model_proto(self.opset_version)

        @onnxscript.script(default_opset=op)
        def expected_model(x: FLOAT[1, 2, 3]):
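Net effect on the tests, summarized: the local `to_model_proto` helper that forwarded an empty `initializers` dict is dropped, and each test calls `TorchScriptGraph.to_model_proto` directly with just the opset version, since the graph now carries its own initializers (per the PR title). A rough before/after sketch; the import path and surrounding setup are assumptions, not copied from the PR:

from onnxscript.function_libs.torch_aten import graph_building  # assumed import path

graph = graph_building.TorchScriptGraph()
tracer = graph_building.TorchScriptTracingEvaluator(graph)
opset_version = 17

# ... trace ops through `tracer` and register graph outputs here ...

# Before this PR: initializers were supplied at export time.
# model = graph.to_model_proto(initializers={}, opset_version=opset_version)

# After this PR: the graph tracks its own initializers; only the opset version is passed.
model = graph.to_model_proto(opset_version)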
Review comment: Would a brief docstring for this property help?
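The property being discussed is not shown in this excerpt. Purely to illustrate the suggestion, a brief docstring on a hypothetical `initializers` property could look like the stub below; the class, name, and types are guesses, not the PR's code:

class TorchScriptGraph:  # illustrative stub only, not the real class
    def __init__(self):
        self._initializers: dict[str, object] = {}

    @property
    def initializers(self) -> dict[str, object]:
        """Mapping from initializer (weight/constant) name to its value, as registered on this graph."""
        return self._initializers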