
Commit 691952c

fix lint
1 parent b008c92 commit 691952c

File tree

2 files changed: +9 −18 lines


onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 4 additions & 9 deletions
@@ -7,7 +7,6 @@
 import onnxscript.ir as ir
 from onnxscript.rewriter import pattern

-
 Dim = Union[int, ir.SymbolicDim]



@@ -43,9 +42,7 @@ def pattern(
         attention_bias,
         num_heads,
         scale,
-        **_,
     ):
-
         projected = op.MatMul(input, qkv_weight)
         # Add bias if present
         if self._has_input_bias:
@@ -71,7 +68,7 @@ def pattern(
             _allow_other_attributes=True,
             _outputs=["value_mm_sliced"],
         )
-
+
         # Split past into past_key and past_value
         # past_key and past_value are of shape (B, H, S, D/H)
         past_key, past_value = op.Split(past, axis=0, split=[1, 1])
@@ -87,7 +84,7 @@ def pattern(
             past_key,
             past_value,
             num_heads=num_heads,
-            scale=scale,
+            scale=scale,
             _domain="com.microsoft",
             _outputs=3,
         )
@@ -124,7 +121,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
            return False
        if no_match(value_mm_sliced, ["B", "S", "Dh_v"]):
            return False
-
+
        # Ensure Dh = Dh_q + Dh_k + Dh_v
        Dh = bindings.get("Dh")
        Dh_q = bindings.get("Dh_q")
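
For readers of the hunk above: the pattern's validity check rejects a match when any symbolic dimension fails to bind or when the packed projection width does not decompose into the query/key/value widths. A minimal self-contained sketch of that condition (illustrative only, not the library's code; packed_qkv_dims_consistent is a hypothetical name, and bindings is assumed to map dimension names to ints):

# Hypothetical sketch of the Dh consistency condition from the hunk above.
def packed_qkv_dims_consistent(bindings: dict) -> bool:
    Dh = bindings.get("Dh")
    Dh_q = bindings.get("Dh_q")
    Dh_k = bindings.get("Dh_k")
    Dh_v = bindings.get("Dh_v")
    if None in (Dh, Dh_q, Dh_k, Dh_v):
        # A dimension failed to bind to a concrete value; skip the fusion.
        return False
    # The packed QKV projection width must split exactly into Q, K, V widths.
    return Dh == Dh_q + Dh_k + Dh_v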
@@ -169,9 +166,7 @@ def rewrite(
         )


-attention_with_input_bias_rule = Attention.rule(
-    "attention_input_bias", has_input_bias=True
-)
+attention_with_input_bias_rule = Attention.rule("attention_input_bias", has_input_bias=True)
 attention_with_no_input_bias_rule = Attention.rule(
     "attention_no_input_bias", has_input_bias=False
 )
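
For context, rules created this way are typically collected into a rule set and applied to a model. A hedged sketch following onnxscript's rewriter conventions (RewriteRuleSet and apply_to_model are assumed from the library and their exact signatures may differ across versions; model is an ir.Model):

# Hedged sketch: applying the two Attention fusion rules defined above.
from onnxscript.rewriter import pattern

rules = pattern.RewriteRuleSet(
    [attention_with_input_bias_rule, attention_with_no_input_bias_rule]
)
count = rules.apply_to_model(model)  # number of rewrites applied
print(f"Applied {count} attention fusion(s)")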

onnxscript/rewriter/ort_fusions/attention_basic_test.py

Lines changed: 5 additions & 9 deletions
@@ -2,13 +2,10 @@
 # Licensed under the MIT License.
 from __future__ import annotations

-import math
 import unittest

 import numpy as np
-import onnx
 import onnxruntime as ort
-import torch

 import onnxscript
 from onnxscript import FLOAT, script
@@ -17,15 +14,15 @@

 msft_op = onnxscript.values.Opset("com.microsoft", 1)

-# This is a basic test that verifies that a
+# This is a basic test that verifies that a
 # proposed expanded computation using packed matmul and ORT's MHA
 # is equivalent to ORT's Attention (for the specific configuration considered).

 # Simple Attention: no rotary embedding, no past key/value, no cos/sin cache


 class AttentionEquivalence(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.batchsize = 2
         self.seqlen = 8
@@ -35,7 +32,7 @@ def __init__(self, *args, **kwargs):
         self.q_hidden_size = 160
         self.k_hidden_size = 160
         self.v_hidden_size = 180
-        #self.num_groups = self.num_heads // self.kv_num_heads
+        # self.num_groups = self.num_heads // self.kv_num_heads

     def random_inputs(self):
         B = self.batchsize
@@ -72,6 +69,7 @@ def expanded_model_script(self):
         Dh_q = self.q_hidden_size
         Dh_qk = self.q_hidden_size + self.k_hidden_size
         Dh_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size
+
         @script()
         def attention(input, weight, bias):
             QKV_no_bias = op.MatMul(input, weight)
@@ -96,9 +94,7 @@ def to_proto(self, model_script):
         D_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size
         return model_script.to_model_proto(
             input_types=(FLOAT["B", "S", D], FLOAT[D, D_qkv], FLOAT[D_qkv]),
-            output_types=(
-                FLOAT["B", "S", self.v_hidden_size],
-            ),
+            output_types=(FLOAT["B", "S", self.v_hidden_size],),
         )

     def test_equivalence(self):
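
The diff truncates before the body of test_equivalence. For orientation, an equivalence test of this shape typically runs both the expanded model and the reference Attention model through onnxruntime on the same inputs and compares outputs numerically. A hypothetical sketch (the helper name, tolerances, and logic are assumptions, not taken from the file; np and ort are the file's imports):

# Hypothetical sketch of an equivalence check in this test's style.
def check_equivalence(expanded_proto, reference_proto, inputs):
    def run(proto):
        # inputs: dict mapping input names to numpy arrays
        sess = ort.InferenceSession(
            proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        return sess.run(None, inputs)[0]

    # Both models should produce numerically close outputs on the same inputs.
    np.testing.assert_allclose(
        run(expanded_proto), run(reference_proto), rtol=1e-3, atol=1e-3
    )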
