Skip to content

Introduce pattern.any_value #2175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions onnxscript/rewriter/ort_fusions/gqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,17 @@ def pattern(
shape_B111,
):
# Reshape query from (B, S, D) to (B, S, H, D/H)
query_BSHDh = op.Reshape(query_BSD, _allow_other_inputs=True, _outputs=["query_BSHDh"])
query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])

# Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
key_BSHkvDh = op.Reshape(key_BSDkv, _allow_other_inputs=True, _outputs=["key_BSHkvDh"])
key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"])
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])

# Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H)
value_BSHkvDh = op.Reshape(
value_BSDkv, _allow_other_inputs=True, _outputs=["value_BSHkvDh"]
)
value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"])
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])

Expand Down Expand Up @@ -129,18 +127,18 @@ def pattern(

key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, _allow_other_inputs=True)
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
key_seq_BHTDh = op.Reshape(
key_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["key_seq_BHTDh"]
key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"]
)

# Concatenate past_value cache and current value, expand across heads
# that share key/value.
value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, _allow_other_inputs=True)
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
value_seq_BHTDh = op.Reshape(
value_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["value_seq_BHTDh"]
value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
)

mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
Expand All @@ -158,7 +156,7 @@ def pattern(
attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3])
# Reshape back to (B, S, D)
attention_BSD = op.Reshape(
attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"]
attention_BSHDh, pattern.ANY_VALUE, _outputs=["attention_BSD"]
)
return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh

Expand Down
17 changes: 17 additions & 0 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,20 @@ def _is_pattern_variable(x: Any) -> bool:
return type(x) is ValuePattern


class AnyValue(ValuePattern):
    """Wildcard value pattern: it matches against any value.

    Use it for an operand position whose actual value is irrelevant to the
    pattern being described.
    """

    def __init__(self) -> None:
        # None is forwarded to the ValuePattern constructor.
        super().__init__(None)

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> AnyValue:
        """Return this very object instead of a copy.

        ``node_map`` is accepted but not consulted: one AnyValue instance
        suffices, so there is nothing to duplicate.
        """
        return self


# Shared module-level AnyValue instance; patterns refer to it as `pattern.ANY_VALUE`.
ANY_VALUE = AnyValue()


class Constant(ValuePattern):
"""Represents a pattern that matches against a scalar constant value."""

Expand Down Expand Up @@ -1108,6 +1122,9 @@ def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bo

def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool:
"""Match an IR value against a ValuePattern instance."""
if isinstance(pattern_value, AnyValue):
return True

if not self._bind_value(pattern_value, value):
return False

Expand Down
21 changes: 21 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,27 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]:
onnxscript.optimizer.inline(model)
self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul", "Add", "Mul"])

def test_any_value(self):
    """A rule using pattern.ANY_VALUE rewrites Add(x, Mul(0, <anything>)) to Identity(x)."""

    def source_pattern(op, x):
        # The second Mul operand is a wildcard.
        return op.Add(x, op.Mul(0, pattern.ANY_VALUE))

    def replacement(op, x):
        return op.Identity(x)

    rule = pattern.RewriteRule(source_pattern, replacement)

    @script()
    def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]:
        zero = op.Constant(value_float=0.0)
        return op.Add(x, op.Mul(zero, y))

    model = ir.serde.deserialize_model(test_model.to_model_proto())

    def op_types():
        # Snapshot of the op types currently in the graph, in graph order.
        return [node.op_type for node in model.graph]

    # Before rewriting: the graph is exactly as written in the script.
    self.assertEqual(op_types(), ["Constant", "Mul", "Add"])

    rule.apply_to_model(model)

    # After rewriting: Mul and Add are replaced by one Identity; the Constant remains.
    self.assertEqual(len(model.graph), 2)
    self.assertEqual(op_types(), ["Constant", "Identity"])


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down
Loading