local change to export llama to qnn #2985

Closed · wants to merge 1 commit
53 changes: 52 additions & 1 deletion backends/qualcomm/quantizer/quantizer.py
@@ -15,7 +15,8 @@
from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer

from executorch.backends.qualcomm.passes.convert_constants_to_attrs import ConvertConstantsToAttrs
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch import Tensor
from torch._ops import OpOverload
from torch.ao.quantization.observer import (
@@ -378,8 +379,58 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
model = DecomposeScaledDotProductAttention()(model).graph_module
model = DecomposeSilu()(model).graph_module
model = ReplaceInfBuffer()(model).graph_module
# ConvertConstantsToAttrs(model)
self._lift_constant_scalar_operands(model)
# model = ConvertBinaryOpsWithScalar()(model).graph_module

return model

def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:
# print("running _lift_constant_scalar_operands...")
for n in gm.graph.nodes:
# if n.name == "mul_78":
# print(" n.name: ", n.name)

if n.op != "call_function" or n.target not in (
torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.mul.Scalar,
torch.ops.aten.rsub.Scalar,
):
continue

# print(" handling n: ", n, " n.target: ", n.target, " n.args: ", n.args)
const_arg = None
non_const_arg = None
for arg in n.args:
if isinstance(arg, torch.fx.Node):
non_const_arg = arg
else:
const_arg = arg

if non_const_arg is None or const_arg is None:
continue

# print(" n'args are all constant: ", n)
tensor_constant = torch.tensor([const_arg], dtype=torch.float32)
tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
gm
)
gm.register_buffer(tensor_constant_name, tensor_constant)

fake_mode = n.meta["val"].fake_mode
with gm.graph.inserting_before(n):
get_attr_node = gm.graph.get_attr(tensor_constant_name)
get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)

if n.target == torch.ops.aten.rsub.Scalar:
n.args = (get_attr_node, non_const_arg) + n.args[2:]
n.target = torch.ops.aten.sub.Tensor
else:
n.args = (non_const_arg, get_attr_node) + n.args[2:]

gm.recompile()

def validate(self, model: GraphModule) -> None:
pass
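
Note: as a quick illustration of the scalar-lifting rewrite above, the same transformation can be reproduced on a toy torch.fx module. This is a hypothetical standalone sketch, not code from the PR; the module name Scale and the buffer name _lifted_scalar_0 are made up, and the fake-tensor metadata bookkeeping done in _lift_constant_scalar_operands is omitted.

import torch
import torch.fx


class Scale(torch.nn.Module):
    def forward(self, x):
        # the 0.5 operand is a plain Python float, not a graph input
        return torch.mul(x, 0.5)


gm = torch.fx.symbolic_trace(Scale())
for n in gm.graph.nodes:
    if n.op != "call_function" or n.target is not torch.mul:
        continue
    # split the args into the tensor input and the constant scalar
    tensor_arg = n.args[0] if isinstance(n.args[0], torch.fx.Node) else n.args[1]
    scalar_arg = n.args[1] if isinstance(n.args[0], torch.fx.Node) else n.args[0]
    if not isinstance(scalar_arg, (int, float)):
        continue
    # lift the scalar into a buffer and feed it back in through a get_attr node
    gm.register_buffer("_lifted_scalar_0", torch.tensor([scalar_arg], dtype=torch.float32))
    with gm.graph.inserting_before(n):
        get_attr_node = gm.graph.get_attr("_lifted_scalar_0")
    n.args = (tensor_arg, get_attr_node)
gm.recompile()

print(gm.graph)            # mul(x, _lifted_scalar_0) instead of mul(x, 0.5)
print(gm(torch.ones(2)))   # tensor([0.5000, 0.5000])
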
18 changes: 15 additions & 3 deletions backends/qualcomm/quantizer/utils.py
@@ -7,6 +7,7 @@
from typing import Callable, Dict, List, Optional, Sequence

import torch
from torch._subclasses import FakeTensor

from torch._ops import OpOverload

@@ -41,6 +42,13 @@ def decorator(annotator: Callable):

return decorator

def _is_input_non_float_tensor(node: Node):
"""Check if the input is not a float tensor, so that we can skip quantization for the node
since observers only works with float Tensors
"""
if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
return True
return node.meta["val"].dtype != torch.float32

def _is_annotated(nodes: List[Node]):
"""
@@ -115,6 +123,7 @@ def annotate_single_in_single_out(


def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
print(f"annotate_binary running for node {node}...")
if _is_annotated([node]):
return

@@ -123,12 +132,14 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None

input_qspec_map = {}
input_act0 = node.args[0]
if isinstance(input_act0, Node):
if isinstance(input_act0, Node) and not _is_input_non_float_tensor(input_act0):
input_qspec_map[input_act0] = input_act_qspec
print(" input_act0: ", input_act0, " _is_input_non_float_tensor: ", _is_input_non_float_tensor(input_act0))

input_act1 = node.args[1]
if isinstance(input_act1, Node):
if isinstance(input_act1, Node) and not _is_input_non_float_tensor(input_act1):
input_qspec_map[input_act1] = input_act_qspec
print(" input_act1: ", input_act1, " _is_input_non_float_tensor: ", _is_input_non_float_tensor(input_act1))

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
@@ -147,7 +158,8 @@ def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar])
# @register_annotator([torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar])
@register_annotator([torch.ops.aten.mul.Tensor])
def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)

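
Note: the new _is_input_non_float_tensor guard keeps non-float inputs (for example int64 position indices) out of input_qspec_map, since observers can only be attached to float tensors. Below is a hypothetical standalone probe, not from the PR, showing which node values the guard would exclude; the AddPos module is made up.

import torch
from torch._subclasses import FakeTensor


class AddPos(torch.nn.Module):
    def forward(self, x, pos):
        return x + pos


ep = torch.export.export(
    AddPos(), (torch.rand(4), torch.ones(4, dtype=torch.int64))
)
for node in ep.graph_module.graph.nodes:
    val = node.meta.get("val", None)
    # only float32 FakeTensor values would receive an input qspec
    annotatable = isinstance(val, FakeTensor) and val.dtype == torch.float32
    print(node.name, getattr(val, "dtype", None), "annotate" if annotatable else "skip")
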
9 changes: 5 additions & 4 deletions examples/models/llama2/export_llama_lib.py
@@ -647,17 +647,17 @@ def _export_llama(modelname, args) -> str: # noqa: C901
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
generate_qnn_executorch_compiler_spec(
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
soc_model=QcomChipset.SM8650, # default to SM8650
soc_model=QcomChipset.SM8450, # default is SM8650
backend_options=backend_options,
debug=False,
saver=False,
),
skip_node_id_set={},
skip_node_op_set={},
skip_node_op_set={"aten.unsqueeze_copy.default", "aten.permute_copy.default"},
)
)
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
_transform(builder_exported_to_edge.export_program())
_transform(builder_exported_to_edge.edge_manager.exported_program())

if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
@@ -678,7 +678,8 @@ def _export_llama(modelname, args) -> str: # noqa: C901
logging.info("Generated etrecord.bin")
else:
builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()

print("graph after to_backend")
builder.edge_manager.exported_program().graph.print_tabular()
if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")

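
Note: the partitioner change above hard-codes skip_node_op_set to keep aten.unsqueeze_copy.default and aten.permute_copy.default off the QNN backend. When experimenting with which ops to skip, a small hypothetical helper (not part of the repo) can list the call_function targets present in an exported program; the exact string form of each target depends on the dialect.

from collections import Counter


def count_ops(exported_program):
    """Print a histogram of call_function targets in an exported program."""
    ops = Counter(
        str(node.target)
        for node in exported_program.graph_module.graph.nodes
        if node.op == "call_function"
    )
    for name, count in ops.most_common():
        print(f"{count:4d}  {name}")


# e.g. count_ops(builder_exported_to_edge.edge_manager.exported_program()),
# assuming the edge_manager accessor used in the diff above.
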
55 changes: 39 additions & 16 deletions examples/models/llama2/llama_transformer.py
@@ -14,7 +14,7 @@
import torch.nn.functional as F

from torch import nn

import math

class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
@@ -216,15 +216,23 @@ def __init__(self, args: ModelArgs, layer_id: int):
self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
self.layer_id = layer_id

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
self.max_seq_len,
dtype=torch.bool,
device="cpu",
)
# causal_mask = torch.tril(
# torch.ones(
# self.max_seq_len,
# self.max_seq_len,
# dtype=torch.bool,
# device="cpu",
# )
# )
# self.register_buffer("mask", causal_mask, persistent=False)
mask = torch.full(
(1, 1, args.max_seq_len, args.max_seq_len),
float("-inf"),
device="cpu",
)
self.register_buffer("mask", causal_mask, persistent=False)

mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask)

if self.use_kv_cache:
self.kv_cache = KVCache(
@@ -264,18 +272,33 @@ def forward(
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
mask = self.mask[None, None, input_pos]
mask = torch.squeeze(self.mask, [0, 1])
mask = mask[None, None, input_pos]
# mask = self.mask[None, None, input_pos]

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0
)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

y = self.wo(y)
return y
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = scores + mask  # additive -inf mask must be applied before the softmax
scores = F.softmax(scores.float(), dim=-1).type_as(q)
output = torch.matmul(
scores, v
) # (bs, n_local_heads, seqlen, head_dim)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)
return output
# y = F.scaled_dot_product_attention(
# q, k, v, attn_mask=mask, dropout_p=0.0
# )

# y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

# y = self.wo(y)
# return y
else:
from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache # noqa

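
Note: the non-SDPA branch above replaces F.scaled_dot_product_attention with explicit matmul/softmax so that the graph lowered to QNN avoids the fused SDPA op. Below is a hypothetical standalone sanity check, not part of the PR, that the rewritten math with the additive -inf mask applied before the softmax matches the fused op.

import math

import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = torch.rand(bsz, n_heads, seqlen, head_dim)
k = torch.rand(bsz, n_heads, seqlen, head_dim)
v = torch.rand(bsz, n_heads, seqlen, head_dim)

# additive causal mask, same construction as in the new Attention.__init__
mask = torch.triu(
    torch.full((1, 1, seqlen, seqlen), float("-inf")), diagonal=1
)

scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + mask  # mask applied before the softmax
scores = F.softmax(scores.float(), dim=-1).type_as(q)
manual = torch.matmul(scores, v)

reference = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(torch.allclose(manual, reference, atol=1e-5))  # expected: True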