|
9 | 9 | import onnx
|
10 | 10 |
|
11 | 11 | import onnxscript
|
| 12 | +from onnxscript import script |
12 | 13 | from onnxscript.rewriter import pattern
|
| 14 | +from onnxscript.values import Opset |
13 | 15 |
|
| 16 | +# Create an opset for the custom domain |
| 17 | +opset = Opset("custom.domain", 1) |
14 | 18 |
|
15 |
| -def create_model_with_custom_domain(): |
| 19 | + |
| 20 | +@script(opset) |
| 21 | +def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]: |
16 | 22 | """Create a model with a Relu operation in a custom domain."""
|
17 |
| - import onnx |
18 |
| - from onnx import helper, TensorProto |
19 |
| - |
20 |
| - # Create input |
21 |
| - input_tensor = helper.make_tensor_value_info('A', TensorProto.FLOAT, [2, 2]) |
22 |
| - |
23 |
| - # Create output |
24 |
| - output_tensor = helper.make_tensor_value_info('result', TensorProto.FLOAT, [2, 2]) |
25 |
| - |
26 |
| - # Create Relu node with custom domain |
27 |
| - relu_node = helper.make_node( |
28 |
| - 'Relu', |
29 |
| - inputs=['A'], |
30 |
| - outputs=['result'], |
31 |
| - domain='custom.domain' # Set the custom domain |
32 |
| - ) |
33 |
| - |
34 |
| - # Create the graph |
35 |
| - graph = helper.make_graph( |
36 |
| - [relu_node], # nodes |
37 |
| - 'custom_domain_model', # name |
38 |
| - [input_tensor], # inputs |
39 |
| - [output_tensor] # outputs |
40 |
| - ) |
41 |
| - |
42 |
| - # Create the model with opset for custom domain |
43 |
| - opset_imports = [ |
44 |
| - helper.make_opsetid("", 18), # Standard ONNX opset |
45 |
| - helper.make_opsetid("custom.domain", 1) # Custom domain opset |
46 |
| - ] |
47 |
| - |
48 |
| - model = helper.make_model(graph, opset_imports=opset_imports) |
49 |
| - return model |
50 |
| - |
51 |
| - |
52 |
| -_model = create_model_with_custom_domain() |
| 23 | + return opset.Relu(input) |
| 24 | + |
| 25 | + |
| 26 | +_model = create_model_with_custom_domain.to_model_proto() |
| 27 | +_model = onnx.shape_inference.infer_shapes(_model) |
53 | 28 | onnx.checker.check_model(_model)
|
54 | 29 |
|
55 | 30 |
|
|
0 commit comments