From 19bf7427ab90fbc0b772e7f7a372384294471b18 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 14 Mar 2025 16:19:10 +0100 Subject: [PATCH 1/7] Make test test_smollm 20% faster --- onnxscript/rewriter/ort_fusions/_test_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index e4eba174fb..99151a0163 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -27,13 +27,13 @@ def _save(model, modelpath): def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] - with tempfile.TemporaryDirectory() as temp_dir: - model_path = os.path.join(temp_dir, f"{model_name}.onnx") - _save(model, model_path) - # Run model - session = onnxruntime.InferenceSession(model_path, providers=providers) - ort_outputs = session.run(None, inputs) - return ort_outputs + onx = ir.serde.serialize_model(model) + opts = onnxruntime.SessionOptions() + opts.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL + ) + session = onnxruntime.InferenceSession(onx.SerializeToString(), opts, providers=providers) + return session.run(None, inputs) def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2): From 77957f83b99b65083a0816d2cd4be45f87f564ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Mar 2025 16:56:18 +0100 Subject: [PATCH 2/7] Update onnxscript/rewriter/ort_fusions/_test_utils.py Co-authored-by: Justin Chu --- onnxscript/rewriter/ort_fusions/_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index 99151a0163..91f391d2b9 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -27,7 +27,7 @@ def _save(model, modelpath): def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] - onx = ir.serde.serialize_model(model) + model_proto = ir.serde.serialize_model(model) opts = onnxruntime.SessionOptions() opts.graph_optimization_level = ( onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL From 81fbf3216e62263317e31fd6ee9c1a344d7e2ca1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Mar 2025 16:56:23 +0100 Subject: [PATCH 3/7] Update onnxscript/rewriter/ort_fusions/_test_utils.py Co-authored-by: Justin Chu --- onnxscript/rewriter/ort_fusions/_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index 91f391d2b9..28c4c4ce3c 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -28,7 +28,7 @@ def _save(model, modelpath): def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] model_proto = ir.serde.serialize_model(model) - opts = onnxruntime.SessionOptions() + options = onnxruntime.SessionOptions() opts.graph_optimization_level = ( onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL ) From d0da6756b18bd0b3aa507b2e9dbed6ba52118954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Mar 2025 16:56:31 +0100 Subject: [PATCH 4/7] Update onnxscript/rewriter/ort_fusions/_test_utils.py Co-authored-by: Justin Chu --- onnxscript/rewriter/ort_fusions/_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index 28c4c4ce3c..25a41a8ced 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -32,7 +32,7 @@ def ort_run(model_name: str, model, inputs): opts.graph_optimization_level = ( onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL ) - session = onnxruntime.InferenceSession(onx.SerializeToString(), opts, providers=providers) + session = onnxruntime.InferenceSession(model_proto.SerializeToString(), options, providers=providers) return session.run(None, inputs) From 7e243d195b671f685466c171935b49e4857b65bd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Mar 2025 09:02:25 -0700 Subject: [PATCH 5/7] Update onnxscript/rewriter/ort_fusions/_test_utils.py --- onnxscript/rewriter/ort_fusions/_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index 25a41a8ced..cc94aa3c49 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -29,7 +29,7 @@ def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] model_proto = ir.serde.serialize_model(model) options = onnxruntime.SessionOptions() - opts.graph_optimization_level = ( + options.graph_optimization_level = ( onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL ) session = onnxruntime.InferenceSession(model_proto.SerializeToString(), options, providers=providers) From a790a7539a578f1bffd745fc7ce49b7f74604661 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 14 Mar 2025 19:21:51 +0100 Subject: [PATCH 6/7] lint --- onnxscript/rewriter/ort_fusions/_test_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index cc94aa3c49..8b639c5112 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -import tempfile import numpy as np import onnx @@ -29,10 +28,10 @@ def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] model_proto = ir.serde.serialize_model(model) options = onnxruntime.SessionOptions() - options.graph_optimization_level = ( - onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL + options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL + session = onnxruntime.InferenceSession( + model_proto.SerializeToString(), options, providers=providers ) - session = onnxruntime.InferenceSession(model_proto.SerializeToString(), options, providers=providers) return session.run(None, inputs) From f1d3ff385bf89b52c72a296a4adfc442bdebb259 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 15 Mar 2025 14:57:16 +0100 Subject: [PATCH 7/7] lint --- onnxscript/rewriter/ort_fusions/_test_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index 8b639c5112..12bdcf2d4d 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -2,8 +2,6 @@ # Licensed under the MIT License. from __future__ import annotations -import os - import numpy as np import onnx import onnxruntime