1717@unittest .skipIf (version_utils .torch_older_than ("2.0" ), "torchscript in 1.13 not supported" )
1818class TestTorchScriptTracingEvaluator (unittest .TestCase ):
1919 def setUp (self ):
20+ # FIXME: Currently this must match with the import line
21+ # `from onnxscript import opset17 as op`, which restricts opset to be 17 in these
22+ # tests anyways.
2023 self .opset_version = 17
24+ # TODO: Add test for initializer. Currently skipped since to `assert_isomorphic`
25+ # does not check for initializers.
2126 self .onnxscript_graph = graph_building .TorchScriptGraph ()
2227 self .tracer = graph_building .TorchScriptTracingEvaluator (self .onnxscript_graph )
2328
24- def to_model_proto (self ):
25- # TODO(titaiwang): initializer API
26- return self .onnxscript_graph .to_model_proto (
27- initializers = {}, opset_version = self .opset_version
28- )
29-
3029 def test_traced_constant_op_is_same_as_compiled_graph (self ):
3130 """Test for op.Constant created in graph builder"""
3231 with evaluator .default_as (self .tracer ):
3332 output = op .Constant (value_float = 0.5 )
3433
3534 self .onnxscript_graph .register_outputs (output )
36- traced = self .to_model_proto ()
35+ traced = self .onnxscript_graph . to_model_proto (self . opset_version )
3736
3837 @onnxscript .script ()
3938 def expected_model ():
@@ -51,7 +50,7 @@ def test_traced_graph_on_single_node_is_same_as_compiled_graph(self):
5150 output = aten_relu (x )
5251
5352 self .onnxscript_graph .register_outputs (output )
54- traced = self .to_model_proto ()
53+ traced = self .onnxscript_graph . to_model_proto (self . opset_version )
5554
5655 @onnxscript .script (default_opset = op )
5756 def expected_model (x : FLOAT [1 , 2 , 3 ]):
@@ -70,7 +69,7 @@ def test_traced_graph_on_single_node_multi_output_is_same_as_compiled_graph(self
7069 output = aten_topk (x , 2 )
7170
7271 self .onnxscript_graph .register_outputs (output )
73- traced = self .to_model_proto ()
72+ traced = self .onnxscript_graph . to_model_proto (self . opset_version )
7473
7574 @onnxscript .script (default_opset = op )
7675 def expected_model (x : FLOAT [1 , 2 , 3 ]):
0 commit comments