Skip to content

Commit aaceba9

Browse files
Copilotgramalingam
andcommitted
Use onnxscript.script with custom Opset in domain_option.py example
Co-authored-by: gramalingam <[email protected]>
1 parent 9c2d145 commit aaceba9

File tree

1 file changed

+12
-37
lines changed

1 file changed

+12
-37
lines changed

docs/tutorial/rewriter/examples/domain_option.py

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,47 +9,22 @@
99
import onnx
1010

1111
import onnxscript
12+
from onnxscript import script
1213
from onnxscript.rewriter import pattern
14+
from onnxscript.values import Opset
1315

16+
# Create an opset for the custom domain
17+
opset = Opset("custom.domain", 1)
1418

15-
def create_model_with_custom_domain():
19+
20+
@script(opset)
21+
def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]:
1622
"""Create a model with a Relu operation in a custom domain."""
17-
import onnx
18-
from onnx import helper, TensorProto
19-
20-
# Create input
21-
input_tensor = helper.make_tensor_value_info('A', TensorProto.FLOAT, [2, 2])
22-
23-
# Create output
24-
output_tensor = helper.make_tensor_value_info('result', TensorProto.FLOAT, [2, 2])
25-
26-
# Create Relu node with custom domain
27-
relu_node = helper.make_node(
28-
'Relu',
29-
inputs=['A'],
30-
outputs=['result'],
31-
domain='custom.domain' # Set the custom domain
32-
)
33-
34-
# Create the graph
35-
graph = helper.make_graph(
36-
[relu_node], # nodes
37-
'custom_domain_model', # name
38-
[input_tensor], # inputs
39-
[output_tensor] # outputs
40-
)
41-
42-
# Create the model with opset for custom domain
43-
opset_imports = [
44-
helper.make_opsetid("", 18), # Standard ONNX opset
45-
helper.make_opsetid("custom.domain", 1) # Custom domain opset
46-
]
47-
48-
model = helper.make_model(graph, opset_imports=opset_imports)
49-
return model
50-
51-
52-
_model = create_model_with_custom_domain()
23+
return opset.Relu(input)
24+
25+
26+
_model = create_model_with_custom_domain.to_model_proto()
27+
_model = onnx.shape_inference.infer_shapes(_model)
5328
onnx.checker.check_model(_model)
5429

5530

0 commit comments

Comments
 (0)