From 6ca05676b02806be6edf45f4cfde9e4fb92536d6 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 9 Apr 2025 13:24:10 -0700 Subject: [PATCH 1/4] Add support for pattern.any_value --- onnxscript/rewriter/pattern.py | 17 +++++++++++++++++ onnxscript/rewriter/pattern_test.py | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 907ebd0b88..307c364ffd 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -634,6 +634,20 @@ def _is_pattern_variable(x: Any) -> bool: return type(x) is ValuePattern +class AnyValue(ValuePattern): + """Represents a pattern that matches against any value.""" + + def __init__(self) -> None: + super().__init__(None) + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> AnyValue: + # A single instance of AnyValue suffices. + return self + + +any_value = AnyValue() + + class Constant(ValuePattern): """Represents a pattern that matches against a scalar constant value.""" @@ -1108,6 +1122,9 @@ def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bo def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool: """Match an IR value against a ValuePattern instance.""" + if isinstance(pattern_value, AnyValue): + return True + if not self._bind_value(pattern_value, value): return False diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 24ae237c20..7fdcefe86c 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -667,6 +667,27 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: onnxscript.optimizer.inline(model) self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul", "Add", "Mul"]) + def test_any_value(self): + def source_pattern(op, x): + return op.Add(x, op.Mul(0, pattern.any_value)) + + def replacement(op, x): + return op.Identity(x) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: + zero = op.Constant(value_float=0.0) + return op.Add(x, op.Mul(zero, y)) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + self.assertEqual([x.op_type for x in model.graph], ["Constant", "Mul", "Add"]) + rule.apply_to_model(model) + self.assertEqual(len(model.graph), 2) + self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"]) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): From 62e05a2ebbffa2305533f4c0a6cb771d5f390db2 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 9 Apr 2025 15:10:17 -0700 Subject: [PATCH 2/4] Use pattern.any_value --- onnxscript/rewriter/ort_fusions/gqa.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 7f761a3744..22db6f98bd 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -87,19 +87,17 @@ def pattern( shape_B111, ): # Reshape query from (B, S, D) to (B, S, H, D/H) - query_BSHDh = op.Reshape(query_BSD, _allow_other_inputs=True, _outputs=["query_BSHDh"]) + query_BSHDh = op.Reshape(query_BSD, pattern.any_value, _outputs=["query_BSHDh"]) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) - key_BSHkvDh = op.Reshape(key_BSDkv, _allow_other_inputs=True, _outputs=["key_BSHkvDh"]) + key_BSHkvDh = op.Reshape(key_BSDkv, pattern.any_value, _outputs=["key_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) - value_BSHkvDh = op.Reshape( - value_BSDkv, _allow_other_inputs=True, _outputs=["value_BSHkvDh"] - ) + value_BSHkvDh = op.Reshape(value_BSDkv, pattern.any_value, _outputs=["value_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) @@ -129,18 +127,18 @@ def pattern( key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) - key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, _allow_other_inputs=True) + key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.any_value) key_seq_BHTDh = op.Reshape( - key_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["key_seq_BHTDh"] + key_seq_BHkvGTDh, pattern.any_value, _outputs=["key_seq_BHTDh"] ) # Concatenate past_value cache and current value, expand across heads # that share key/value. value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) - value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, _allow_other_inputs=True) + value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.any_value) value_seq_BHTDh = op.Reshape( - value_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["value_seq_BHTDh"] + value_seq_BHkvGTDh, pattern.any_value, _outputs=["value_seq_BHTDh"] ) mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) @@ -158,7 +156,7 @@ def pattern( attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) # Reshape back to (B, S, D) attention_BSD = op.Reshape( - attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"] + attention_BSHDh, pattern.any_value, _outputs=["attention_BSD"] ) return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh From aa1300238c97ee37f844bf1809e9fa802ef5ceda Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 10 Apr 2025 09:11:32 -0700 Subject: [PATCH 3/4] Change any_value to ANY_VALUE --- onnxscript/rewriter/ort_fusions/gqa.py | 16 ++++++++-------- onnxscript/rewriter/pattern.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 22db6f98bd..266987dd4d 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -87,17 +87,17 @@ def pattern( shape_B111, ): # Reshape query from (B, S, D) to (B, S, H, D/H) - query_BSHDh = op.Reshape(query_BSD, pattern.any_value, _outputs=["query_BSHDh"]) + query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"]) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) - key_BSHkvDh = op.Reshape(key_BSDkv, pattern.any_value, _outputs=["key_BSHkvDh"]) + key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) - value_BSHkvDh = op.Reshape(value_BSDkv, pattern.any_value, _outputs=["value_BSHkvDh"]) + value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) @@ -127,18 +127,18 @@ def pattern( key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) - key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.any_value) + key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE) key_seq_BHTDh = op.Reshape( - key_seq_BHkvGTDh, pattern.any_value, _outputs=["key_seq_BHTDh"] + key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"] ) # Concatenate past_value cache and current value, expand across heads # that share key/value. value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) - value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.any_value) + value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE) value_seq_BHTDh = op.Reshape( - value_seq_BHkvGTDh, pattern.any_value, _outputs=["value_seq_BHTDh"] + value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"] ) mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) @@ -156,7 +156,7 @@ def pattern( attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) # Reshape back to (B, S, D) attention_BSD = op.Reshape( - attention_BSHDh, pattern.any_value, _outputs=["attention_BSD"] + attention_BSHDh, pattern.ANY_VALUE, _outputs=["attention_BSD"] ) return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 307c364ffd..cfca31125f 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -645,7 +645,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> AnyValue: return self -any_value = AnyValue() +ANY_VALUE = AnyValue() class Constant(ValuePattern): From 9e079d4e97d936ee7d061b9ad7833c15e5fc59be Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 10 Apr 2025 09:39:21 -0700 Subject: [PATCH 4/4] Update test also --- onnxscript/rewriter/pattern_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 7fdcefe86c..ce11e23c19 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -669,7 +669,7 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: def test_any_value(self): def source_pattern(op, x): - return op.Add(x, op.Mul(0, pattern.any_value)) + return op.Add(x, op.Mul(0, pattern.ANY_VALUE)) def replacement(op, x): return op.Identity(x)