@@ -2,13 +2,10 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-import math
 import unittest
 
 import numpy as np
-import onnx
 import onnxruntime as ort
-import torch
 
 import onnxscript
 from onnxscript import FLOAT, script
@@ -17,15 +14,15 @@
 
 msft_op = onnxscript.values.Opset("com.microsoft", 1)
 
-# This is a basic test that verifies that a
+# This is a basic test that verifies that a
 # proposed expanded computation using packed matmul and ORT's MHA
 # is equivalent to ORT's Attention (for the specific configuration considered).
 
 # Simple Attention: no rotary embedding, no past key/value, no cos/sin cache
 
 
 class AttentionEquivalence(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.batchsize = 2
         self.seqlen = 8
@@ -35,7 +32,7 @@ def __init__(self, *args, **kwargs):
         self.q_hidden_size = 160
         self.k_hidden_size = 160
         self.v_hidden_size = 180
-        #self.num_groups = self.num_heads // self.kv_num_heads
+        # self.num_groups = self.num_heads // self.kv_num_heads
 
     def random_inputs(self):
         B = self.batchsize
@@ -72,6 +69,7 @@ def expanded_model_script(self):
         Dh_q = self.q_hidden_size
         Dh_qk = self.q_hidden_size + self.k_hidden_size
         Dh_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size
+
         @script()
         def attention(input, weight, bias):
             QKV_no_bias = op.MatMul(input, weight)
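
Aside (not part of the commit): the expanded script computes Q, K, and V with a
single packed MatMul and then splits the result by offset along the last axis,
which is what the Dh_q/Dh_qk/Dh_qkv offsets above are for. A minimal NumPy
sketch of that computation follows; the input hidden size D is not visible in
this excerpt, so 160 is an assumed example value.

    import numpy as np

    B, S, D = 2, 8, 160                  # batchsize/seqlen from the test; D assumed
    Dh_q, Dh_k, Dh_v = 160, 160, 180     # q/k/v hidden sizes from the test
    Dh_qk = Dh_q + Dh_k                  # split offsets, as in the hunk above
    Dh_qkv = Dh_qk + Dh_v

    rng = np.random.default_rng(0)
    x = rng.standard_normal((B, S, D)).astype(np.float32)
    w = rng.standard_normal((D, Dh_qkv)).astype(np.float32)
    b = rng.standard_normal(Dh_qkv).astype(np.float32)

    qkv = x @ w + b           # one packed projection (MatMul + Add in the script)
    q = qkv[..., :Dh_q]       # offsets 0 .. Dh_q
    k = qkv[..., Dh_q:Dh_qk]  # offsets Dh_q .. Dh_qk
    v = qkv[..., Dh_qk:]      # offsets Dh_qk .. Dh_qkv
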
@@ -96,9 +94,7 @@ def to_proto(self, model_script):
         D_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size
         return model_script.to_model_proto(
             input_types=(FLOAT["B", "S", D], FLOAT[D, D_qkv], FLOAT[D_qkv]),
-            output_types=(
-                FLOAT["B", "S", self.v_hidden_size],
-            ),
+            output_types=(FLOAT["B", "S", self.v_hidden_size],),
         )
 
     def test_equivalence(self):
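
Aside (not part of the commit): the body of test_equivalence falls outside the
diff context. Under the assumption that to_proto() returns an onnx.ModelProto
whose graph inputs are named after the script-function parameters, a
hypothetical driver for the equivalence check could look like the sketch below
(the ort_attention_script name is invented for illustration):

    def run(model_proto, x, w, b):
        sess = ort.InferenceSession(
            model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        return sess.run(None, {"input": x, "weight": w, "bias": b})[0]

    expected = run(self.to_proto(self.ort_attention_script()), *inputs)
    actual = run(self.to_proto(self.expanded_model_script()), *inputs)
    np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)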