
Commit 9c2342d

add utilities
1 parent c0bb9aa commit 9c2342d

File tree

onnxscript/rewriter/_fusion_utils.py
onnxscript/rewriter/ort_fusions/attention.py
onnxscript/rewriter/ort_fusions/attention_test.py
onnxscript/rewriter/ort_fusions/mha.py

4 files changed: 46 additions, 46 deletions


onnxscript/rewriter/_fusion_utils.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from typing import Sequence, Union
+
+
+from onnxscript import ir
+
+Dim = Union[int, ir.SymbolicDim]
+
+def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
+    if val.shape is None:
+        return False
+    if val.shape.rank() != len(shape):
+        return False
+    for actual, expected in zip(val.shape, shape):
+        if expected not in bindings:
+            bindings[expected] = actual  # type: ignore[assignment]
+        elif actual != bindings[expected]:
+            return False
+    return True
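
Not part of the diff, but for context: a minimal sketch of how the shared _check_shape helper accumulates symbolic-dimension bindings across values. The direct ir.Value / ir.Shape construction below is an assumption made for illustration; inside a fusion rule these values come from the matched pattern instead.

# Illustrative sketch only (not code from this commit); constructor usage is assumed.
from onnxscript import ir
from onnxscript.rewriter import _fusion_utils

bindings: dict[str, _fusion_utils.Dim] = {}

# Two values sharing symbolic batch ("B") and sequence ("S") dimensions.
x = ir.Value(name="x", shape=ir.Shape(["B", "S", 640]))
y = ir.Value(name="y", shape=ir.Shape(["B", "S", 160]))

# First call binds "B", "S", and "D" to the dimensions actually found on x.
assert _fusion_utils._check_shape(bindings, x, ["B", "S", "D"])

# Later calls must agree with existing bindings; "Dv" is new and gets bound here.
assert _fusion_utils._check_shape(bindings, y, ["B", "S", "Dv"])

# A value whose "S" dimension disagrees with the earlier binding fails the check.
z = ir.Value(name="z", shape=ir.Shape(["B", 4, 640]))
assert not _fusion_utils._check_shape(bindings, z, ["B", "S", "D"])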

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 4 additions & 14 deletions
@@ -5,23 +5,12 @@
 from typing import Sequence, Union

 import onnxscript.ir as ir
-from onnxscript.rewriter import pattern
+from onnxscript.rewriter import _fusion_utils, pattern

 Dim = Union[int, ir.SymbolicDim]


 # TODO: Maybe add this check to utilities
-def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
-    if val.shape is None:
-        return False
-    if val.shape.rank() != len(shape):
-        return False
-    for actual, expected in zip(val.shape, shape):
-        if expected not in bindings:
-            bindings[expected] = actual  # type: ignore[assignment]
-        elif actual != bindings[expected]:
-            return False
-    return True


 class AttentionFusion(pattern.RewriteRuleClassBase):
@@ -103,6 +92,7 @@ def pattern(
             present_key = op.Unsqueeze(present_key, [0])
             present_value = op.Unsqueeze(present_value, [0])
             present = op.Concat(present_key, present_value, axis=0)
+            # Return present output first as it captures the complete pattern graph
             return present, attention
         else:
             attention = op.MultiHeadAttention(
@@ -136,7 +126,7 @@ def check(
         self.bindings: dict[str, Dim] = {}

         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _check_shape(self.bindings, val, dims)
+            return not _fusion_utils._check_shape(self.bindings, val, dims)

         if no_match(input, ["B", "S", "D"]):
             return check_result.fail(
@@ -228,7 +218,7 @@ def rewrite(
                 _domain="com.microsoft",
                 _outputs=2,
             )
-            # Return present output first as it captures the complete rewrite pattern graph
+            # Use same output ordering as in pattern
             return present, attention
         else:
             return op.Attention(

onnxscript/rewriter/ort_fusions/attention_test.py

Lines changed: 18 additions & 17 deletions
@@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.batchsize = 2
         self.seqlen = 8
-        self.max_seqlen = 32
+        self.past_seqlen = 32
         self.headsize = 16
         self.num_heads = 10
         self.input_hidden_size = self.headsize * self.num_heads
@@ -36,7 +36,7 @@ def random_inputs(self, with_past=False):
         """Generate random inputs for the model."""
         B = self.batchsize
         S = self.seqlen
-        M = self.max_seqlen
+        Sp = self.past_seqlen
         D = self.input_hidden_size
         N = self.num_heads
         H = self.headsize
@@ -48,22 +48,22 @@ def random_inputs(self, with_past=False):
             "bias": np.random.rand(D_qkv).astype(np.float32),
         }
         if with_past:
-            inputs["past"] = np.random.rand(2, B, N, M, H).astype(np.float32)
+            inputs["past"] = np.random.rand(2, B, N, Sp, H).astype(np.float32)
         return inputs

     def create_model(self, with_past=False):
         """Create a model with or without past inputs."""
         D = self.input_hidden_size
-        Dh_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size
+        D_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size

         @script()
         def model_with_mha(input, weight, bias):
-            QKV_no_bias = op.MatMul(input, weight)
-            QKV = op.Add(QKV_no_bias, bias)
+            qkv_no_bias = op.MatMul(input, weight)
+            qkv = op.Add(qkv_no_bias, bias)

-            query_BSDh = op.Slice(QKV, [0], [160], [2])
-            key_BSDh = op.Slice(QKV, [160], [320], [2])
-            value_BSDh = op.Slice(QKV, [320], [480], [2])
+            query_BSDh = op.Slice(qkv, [0], [160], [2])
+            key_BSDh = op.Slice(qkv, [160], [320], [2])
+            value_BSDh = op.Slice(qkv, [320], [480], [2])

             mha = msft_op.MultiHeadAttention(
                 query_BSDh,
@@ -75,12 +75,12 @@ def model_with_mha(input, weight, bias):

         @script()
         def model_with_mha_past(input, weight, bias, past):
-            QKV_no_bias = op.MatMul(input, weight)
-            QKV = op.Add(QKV_no_bias, bias)
+            qkv_no_bias = op.MatMul(input, weight)
+            qkv = op.Add(qkv_no_bias, bias)

-            query_BSDh = op.Slice(QKV, [0], [160], [2])
-            key_BSDh = op.Slice(QKV, [160], [320], [2])
-            value_BSDh = op.Slice(QKV, [320], [480], [2])
+            query_BSDh = op.Slice(qkv, [0], [160], [2])
+            key_BSDh = op.Slice(qkv, [160], [320], [2])
+            value_BSDh = op.Slice(qkv, [320], [480], [2])

             past_key_5d = op.Slice(past, [0], [1], [0])
             past_value_5d = op.Slice(past, [1], [2], [0])
@@ -106,14 +106,15 @@ def model_with_mha_past(input, weight, bias, past):

         input_types = (
             FLOAT["B", "S", D],
-            FLOAT[D, Dh_qkv],
-            FLOAT[Dh_qkv],
+            FLOAT[D, D_qkv],
+            FLOAT[D_qkv],
         )
         output_types = (FLOAT["B", "S", self.v_hidden_size],)

         if with_past:
+            # "T" indicates total sequence length (after concatenation of past and current key/value)
             input_types += (FLOAT[2, "B", self.num_heads, "S", self.headsize],)
-            output_types += (FLOAT[2, "B", self.num_heads, "NS", self.headsize],)
+            output_types += (FLOAT[2, "B", self.num_heads, "T", self.headsize],)
             model_proto = model_with_mha_past.to_model_proto(
                 input_types=input_types,
                 output_types=output_types,
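
Aside (not part of the diff): the literal Slice boundaries 160/320/480 in the test models above follow directly from the configured head dimensions. A plain-Python sketch of that arithmetic; treating the q, k, and v hidden sizes as all equal to headsize * num_heads is an assumption implied by the slice bounds:

# Illustrative arithmetic behind the Slice boundaries in the test models.
headsize = 16
num_heads = 10
hidden = headsize * num_heads   # 160; assumed equal for the q, k, and v projections

q_end = hidden                  # 160 -> op.Slice(qkv, [0], [160], [2])
k_end = q_end + hidden          # 320 -> op.Slice(qkv, [160], [320], [2])
v_end = k_end + hidden          # 480 -> op.Slice(qkv, [320], [480], [2])

print(q_end, k_end, v_end)      # 160 320 480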

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 2 additions & 15 deletions
@@ -5,7 +5,7 @@
 from typing import Sequence, Union

 import onnxscript.ir as ir
-from onnxscript.rewriter import _ir_utils, pattern
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

 """
 The MultiHeadAttention pattern: generate an instance
@@ -31,19 +31,6 @@
 Dim = Union[int, ir.SymbolicDim]


-def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
-    if val.shape is None:
-        return False
-    if val.shape.rank() != len(shape):
-        return False
-    for actual, expected in zip(val.shape, shape):
-        if expected not in bindings:
-            bindings[expected] = actual  # type: ignore[assignment]
-        elif actual != bindings[expected]:
-            return False
-    return True
-
-
 class MultiHeadAttention(pattern.RewriteRuleClassBase):
     def __init__(self, name, *, transpose_4d: bool):
         super().__init__(name)
@@ -168,7 +155,7 @@ def check(
         bindings: dict[str, Dim] = {}

         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
-            return not _check_shape(bindings, val, dims)
+            return not _fusion_utils._check_shape(bindings, val, dims)

         if no_match(query_BSD, ["B", "S", "D"]):
             return check_result.fail(

0 commit comments
