Commit 54caa37
Qualcomm AI Engine Direct - Optimize the performance for AR-N model

Summary:
- Fix a bug in the RMS norm builder.
- Use the HuggingFace version of RoPE to improve performance, since its half-split layout yields stride = 1 slices in the StridedSlice op.
- Modify the axis order of the convs in qkv, feedforward, and output:
  - Original (AR: 128, CL: 2048): QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,2048,1) -> QNN_Transpose (1,128,1,2048) -> self.output -> QNN_Transpose (1,128,2048,1) -> QNN_Reshape (1,1,128,2048)
  - New: QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,1,2048) -> QNN_Transpose (1,1,128,2048) -> self.output -> QNN_Transpose (1,128,1,2048) -> QNN_Reshape (1,1,128,2048)
1 parent 51901f3 commit 54caa37
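For context on the RoPE change: the original implementation pairs rotary channels by even/odd interleaving, which lowers to stride-2 slice ops on the HTP backend, while the HuggingFace layout splits the head dimension in half, so both slices have stride 1. A minimal standalone sketch of the two slicing patterns (illustrative only, not code from this commit):

```python
import torch

head_dim = 8
x = torch.arange(head_dim, dtype=torch.float32)

# Interleaved pairing (original): stride-2 slices, slow on HTP.
x_r_meta, x_i_meta = x[..., ::2], x[..., 1::2]

# HuggingFace pairing (new): contiguous half-splits, stride-1 slices.
x_r_hf, x_i_hf = x[..., : head_dim // 2], x[..., head_dim // 2 :]

print(x_r_meta, x_r_hf)  # tensor([0., 2., 4., 6.]) vs tensor([0., 1., 2., 3.])
```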

File tree

5 files changed: +73 −51 lines changed

backends/qualcomm/_passes/fuse_consecutive_transpose.py (+16 −25)

```diff
@@ -55,12 +55,6 @@ def _clone_transpose(
             clone_permute_node.meta = n.meta
             users[i].replace_input_with(n, clone_permute_node)

-    def _is_dispensable(self, axis_order):
-        for index, value in enumerate(axis_order):
-            if index != value:
-                return False
-        return True
-
     def _traverse(self, node):
         if node in self.visited or node.target not in self.op_map:
             return
@@ -87,25 +81,22 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             axis_order = torch.arange(len(input_shape)).tolist()
             for node in self.nodes:
                 axis_order = [axis_order[i] for i in node.args[1]]
-            # If axis order is just [0,1,2,3], we ignore permute node
-            if self._is_dispensable(axis_order):
-                for user in output_node.users.copy():
-                    user.replace_input_with(output_node, n.args[0])
-            else:
-                with graph.inserting_after(input_node):
-                    permute_op = exir_ops.edge.aten.permute_copy.default
-                    permute_node = graph.create_node(
-                        "call_function", permute_op, (input_node, axis_order)
-                    )
-                    users = output_node.users.copy()
-                    for user in users:
-                        user.replace_input_with(output_node, permute_node)
-
-                    # copy metadata
-                    permute_node.meta = output_node.meta
-                    # Without "qnn_permute", we might obtain wrong input shape
-                    if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
-                        permute_node.meta[QCOM_INSERTED_PERMUTE] = True
+
+            # Reserve [0,1,2,3] permute node to ensure the next node gets the right axis order.
+            with graph.inserting_after(input_node):
+                permute_op = exir_ops.edge.aten.permute_copy.default
+                permute_node = graph.create_node(
+                    "call_function", permute_op, (input_node, axis_order)
+                )
+                users = output_node.users.copy()
+                for user in users:
+                    user.replace_input_with(output_node, permute_node)
+
+                # copy metadata
+                permute_node.meta = output_node.meta
+                # Without "qnn_permute", we might obtain wrong input shape
+                if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
+                    permute_node.meta[QCOM_INSERTED_PERMUTE] = True

             # clear current stack
             self.nodes = []
```
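The pass folds a chain of permutes into one axis order and, after this change, always materializes that single permute, even when the composed order is the identity, so downstream QNN nodes still see an explicit axis order. A standalone sketch of the composition rule used above (hypothetical shapes and permutations, not the pass itself):

```python
import torch

x = torch.randn(1, 128, 1, 2048)
perms = [(0, 2, 3, 1), (0, 3, 1, 2)]  # two consecutive permute nodes

# Fold the chain exactly as _fuse does: axis_order = [axis_order[i] for i in args].
axis_order = list(range(x.dim()))
for p in perms:
    axis_order = [axis_order[i] for i in p]

# One fused permute matches the chained ones.
assert torch.equal(x.permute(axis_order), x.permute(perms[0]).permute(perms[1]))
print(axis_order)  # [0, 1, 2, 3]: the identity; the pass now keeps the node anyway
```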

backends/qualcomm/_passes/recompose_rms_norm.py (+11 −5)

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import torch
+from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -16,8 +17,9 @@ class RecomposeRmsNorm(ExportPass):
     Merge decomposed operators back to one super node.
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(self, edge_program: torch.export.ExportedProgram):
+        super(RecomposeRmsNorm, self).__init__()
+        self.edge_program = edge_program

     def _get_eps_node(self, nodes):
         # eps: one of inputs of add node
@@ -47,11 +49,15 @@ def call(self, graph_module: torch.fx.GraphModule):
             input_node = inp_0 if len(inp_0.users) == 2 else inp_1
         else:
             raise RuntimeError(
-                f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs"
+                f"Found a edge case of rms_node partition {src_partition}, which has {input_len} inputs"
             )

         output_node = src_partition.output_nodes[0]
-        eps_node = self._get_eps_node(src_partition.nodes)
+        eps = self._get_eps_node(src_partition.nodes)
+        if isinstance(eps, torch.fx.Node) and is_parameter(
+            eps, self.edge_program
+        ):
+            eps = get_parameter(eps, self.edge_program).item()
         gamma_node = self._get_gamma_node(output_node)

         with graph.inserting_before(output_node):
@@ -64,7 +70,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                     input_node,
                     list(gamma_node.meta["val"].shape),
                     gamma_node,
-                    eps_node,
+                    eps,
                 ),
             )
             users = output_node.users.copy()
```
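With this change the pass hands the builder a plain Python float for `eps` instead of an fx node, extracting it from the exported program when the decomposition stored it as a parameter. A rough reference of the math the recomposed super node represents, with `eps` as a scalar (a sketch, not the backend op):

```python
import torch

def rms_norm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float) -> torch.Tensor:
    # x / rms(x) * gamma: the computation the fused rms_norm node stands for.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * gamma

# A 0-d parameter tensor collapses to a float via .item(), as the pass now does.
eps = torch.tensor(1e-5).item()
y = rms_norm_ref(torch.randn(1, 128, 2048), torch.ones(2048), eps)
```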

backends/qualcomm/builders/op_rms_norm.py (+7 −10)

```diff
@@ -12,7 +12,11 @@

 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DATA,
+    QCOM_QUANT_ATTRS,
+    QCOM_ZERO_POINT,
+)
 from executorch.exir.dialects._ops import ops as exir_ops

 from .node_visitor import NodeVisitor, register_node_visitor
@@ -66,7 +70,7 @@ def define_node(
             nodes_to_wrappers,
         )

-        # Fake node, nn module seems to be inconsistant with document
+        # Fake node, nn module seems to be inconsistent with document
         bias_tensor = torch.zeros(weight_tensor.shape)
         bias_node = torch.fx.Node(
             node.graph,
@@ -78,6 +82,7 @@ def define_node(
         )
         if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
             bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+            bias_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] = 0
         bias_tensor_wrapper = self.define_tensor(
             bias_node,
             node,
@@ -87,14 +92,6 @@ def define_node(
         )

         epsilon = node.args[3]
-        if isinstance(epsilon, torch.fx.Node):
-            epsilon = get_parameter(epsilon, self.edge_program)
-            epsilon = (
-                epsilon
-                if isinstance(epsilon, float)
-                else torch.finfo(epsilon.dtype).eps
-            )
-
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
```
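The builder fabricates an all-zero bias tensor for the QNN RmsNorm op; forcing its zero-point to 0 keeps that bias exactly zero after dequantization, since under the usual affine scheme `real = scale * (q - zero_point)` a copied nonzero zero-point would shift it. A small numeric illustration (assumed affine parameters, not values from the backend):

```python
scale, zero_point = 0.05, 3  # hypothetical quant attrs copied from the node
q_bias = 0                   # the fake bias is stored as zeros

shifted = scale * (q_bias - zero_point)  # -0.15: the zero bias leaks a constant
correct = scale * (q_bias - 0)           # 0.0 with the zero-point forced to 0
print(shifted, correct)
```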

examples/qualcomm/oss_scripts/llama/llama.py (+22)

```diff
@@ -539,6 +539,28 @@ def compile(args, pte_filename, tokenizer):
     if "model" in state_dict:
         state_dict = state_dict["model"]

+    # Change to HuggingFace weight to improve the performance of RoPE in HTP backend.
+    def permute(w, heads):
+        dim_0 = w.size(0)
+        dim_1 = w.size(1)
+        return (
+            w.view(heads, dim_0 // heads // 2, 2, dim_1)
+            .transpose(1, 2)
+            .reshape(dim_0, dim_1)
+        )
+
+    n_heads = llama_instance_list[0].n_heads
+    n_kv_heads = llama_instance_list[0].n_kv_heads
+    n_layers = llama_instance_list[0].n_layers
+
+    for layer_i in range(n_layers):
+        state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute(
+            state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads
+        )
+        state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute(
+            state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads
+        )
+
     for llama_instance in llama_instance_list:
         llama_instance.load_state_dict(
             state_dict,
```
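This `permute` is the standard Meta-to-HuggingFace checkpoint conversion: within each head it regroups rows so that interleaved rotary pairs (0, 1), (2, 3), ... become the half-split layout the new `apply_rotary_emb_single` expects. A quick standalone check on a toy weight (hypothetical sizes):

```python
import torch

def permute(w, heads):
    dim_0, dim_1 = w.size(0), w.size(1)
    return (
        w.view(heads, dim_0 // heads // 2, 2, dim_1)
        .transpose(1, 2)
        .reshape(dim_0, dim_1)
    )

heads, head_dim, dim = 2, 4, 8
w = torch.arange(heads * head_dim * dim, dtype=torch.float32).view(heads * head_dim, dim)

# Rows of head 0 go from (0, 1, 2, 3) to (0, 2, 1, 3): even-indexed rows
# first, then odd-indexed ones, i.e. half-split instead of interleaved.
assert torch.equal(permute(w, heads)[:head_dim], w[[0, 2, 1, 3]])
```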

examples/qualcomm/oss_scripts/llama/model/static_llama.py (+17 −11)

```diff
@@ -19,12 +19,12 @@
 def apply_rotary_emb_single(
     x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
 ) -> torch.Tensor:
-    x_r, x_i = x[..., ::2], x[..., 1::2]
-
+    # Change to RoPE of huggingface version
+    x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
     # brodcast for batch_prefill mode input x
     if x.dim() == 4:
-        freqs_cos = freqs_cos[None, :, None, :]
-        freqs_sin = freqs_sin[None, :, None, :]
+        freqs_cos = freqs_cos[None, None, :, :]
+        freqs_sin = freqs_sin[None, None, :, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos

@@ -108,21 +108,27 @@ def forward_sha(
             hidden_states, (bsz, seq_len, 1, self.dim)
         ).transpose(1, 3)
         q = [
-            wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wq_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wq_sha in self.wq_sha
         ]
         k = [
-            wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wk_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wk_sha in self.wk_sha
         ]
         v = [
-            wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wv_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wv_sha in self.wv_sha
         ]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
         for i in range(len(k)):
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)
+            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)

         output_y = []
         kh, vh = [], []
@@ -249,10 +255,10 @@ def prepare_feedfoward_conv(self):

     def forward_feedfoward_conv(self, x):
         bsz, _, _ = x.size()
-        x = torch.reshape(x, (bsz, -1, self.dim, 1))
-        x = x.transpose(1, 2)  # Transpose right before and after Conv
+        x = torch.reshape(x, (bsz, -1, 1, self.dim))
+        x = x.transpose(1, 3)  # Transpose right before and after Conv
         x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
-        x = x.transpose(1, 2)
+        x = x.transpose(1, 3)
         x = torch.reshape(x, (bsz, -1, self.dim))
         return x
```
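The feedforward change keeps the same elements flowing through the 1x1 convs and only moves the sequence axis from the conv's height to its width, which is what lets the surrounding QNN reshape/transpose chain shrink as described in the commit summary. A shape-level sketch (hypothetical sizes, convs omitted):

```python
import torch

bsz, seq_len, dim = 1, 128, 2048
x = torch.randn(bsz, seq_len, dim)

# Old layout: sequence on the height axis -> (bsz, dim, seq_len, 1)
old = torch.reshape(x, (bsz, -1, dim, 1)).transpose(1, 2)

# New layout: sequence on the width axis -> (bsz, dim, 1, seq_len)
new = torch.reshape(x, (bsz, -1, 1, dim)).transpose(1, 3)

assert old.shape == (bsz, dim, seq_len, 1)
assert new.shape == (bsz, dim, 1, seq_len)
# Same (channel, position) contents; only the H/W roles of the conv swap.
assert torch.equal(old.squeeze(-1), new.squeeze(2))
```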
