
[Draft] Qualcomm AI Engine Direct - Enable story llama model in quantized and fp #4030

Closed
1 change: 1 addition & 0 deletions backends/qualcomm/partition/common_defs.py
@@ -14,6 +14,7 @@
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.slice_scatter.default,
exir_ops.edge.aten.copy.default,
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
]

allow_list_operator = [
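The new entry keeps the 4-bit quantized embedding op out of QNN delegation so it falls back to the CPU kernel. A minimal sketch of how such a deny list is consulted (assuming the edited list is the backend's `not_supported_operator`; the helper name is hypothetical):

```python
import torch.fx

# Minimal sketch, not the backend's actual code: reject any node whose edge-op
# target appears on the deny list so it is never lowered to QNN.
def should_stay_on_cpu(node: torch.fx.Node, not_supported_operator: list) -> bool:
    return node.target in not_supported_operator
```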
17 changes: 2 additions & 15 deletions backends/qualcomm/partition/qnn_partitioner.py
@@ -40,16 +40,7 @@ def __init__(
):
self.node_visitors = node_visitor.get_node_visitors(edge_program)

self.skip_node_op_builder_set = set()
if skip_node_op_set is not None:
self.skip_node_op_builder_set = set(
[
self.node_visitors[val]
for val in skip_node_op_set
if val in self.node_visitors
]
)

self.skip_node_op_set = skip_node_op_set
self.skip_node_id_set = skip_node_id_set
self.nodes_to_wrappers = defaultdict(dict)
self.qnn_manager = PyQnnManager.QnnManager(
@@ -69,11 +60,7 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
return False

if (
self.skip_node_op_builder_set is not None
and self.node_visitors[node.target.__name__]
in self.skip_node_op_builder_set
):
if self.skip_node_op_set is not None and node.target.__name__ in self.skip_node_op_set:
print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
return False

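The partitioner no longer pre-builds a set of node-visitor objects for the skipped ops; it stores the raw op-name set and compares `node.target.__name__` against it directly. A small self-contained sketch of the new check (the op names below are placeholders):

```python
from typing import Optional, Set

# Sketch of the simplified skip check; "aten.embedding.default" is only an example.
skip_node_op_set: Optional[Set[str]] = {"aten.embedding.default"}

def is_skipped(op_name: str) -> bool:
    # Previously the op name was first mapped to a node-visitor object and the
    # object compared; now the name itself is enough.
    return skip_node_op_set is not None and op_name in skip_node_op_set

print(is_skipped("aten.embedding.default"))  # True
print(is_skipped("aten.add.Tensor"))         # False
```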
2 changes: 2 additions & 0 deletions backends/qualcomm/passes/layout_transform.py
@@ -53,6 +53,8 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.hardswish.default,
exir_ops.edge.aten.hardsigmoid.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.index_put.default,
exir_ops.edge.aten.leaky_relu.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten._log_softmax.default,
5 changes: 3 additions & 2 deletions backends/qualcomm/passes/replace_inf_buffer.py
@@ -14,8 +14,9 @@ def __init__(self):
def call(self, graph_module: torch.fx.GraphModule):
for buf_name, tensor in graph_module.named_buffers():
if tensor.is_floating_point():
tensor[tensor == float("inf")] = torch.finfo(torch.float32).max
tensor[tensor == float("-inf")] = torch.finfo(torch.float32).min
# An arbitrary number
tensor[tensor == float("inf")] = 1000
tensor[tensor == float("-inf")] = -1000
setattr(graph_module, buf_name, tensor)

graph_module.recompile()
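The pass now clamps ±inf buffer entries to ±1000 rather than the float32 extremes, presumably so the values stay well inside range if the backend later runs the buffers at reduced precision (float32's max overflows to inf when cast to fp16). A standalone illustration of the substitution:

```python
import torch

# Standalone illustration: build a causal-style mask with -inf, then apply the
# same substitution the pass performs on floating-point buffers.
causal = torch.triu(torch.ones(4, 4), diagonal=1).bool()
mask = torch.zeros(4, 4).masked_fill(causal, float("-inf"))

mask[mask == float("inf")] = 1000.0    # finite sentinel used by the pass
mask[mask == float("-inf")] = -1000.0

assert torch.isfinite(mask).all()
assert torch.isfinite(mask.to(torch.float16)).all()  # survives an fp16 cast
```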
43 changes: 36 additions & 7 deletions examples/models/llama2/export_llama_lib.py
@@ -52,7 +52,9 @@
from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
from .source_transformation.sdpa import (
replace_causal_mask,
replace_kv_cache_with_simple_kv_cache,
replace_sdpa_with_custom_op,
replace_sdpa_with_flex_sdpa,
replace_sdpa_with_simple_sdpa,
)

@@ -118,7 +120,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--embedding-quantize",
default=None,
type=str,
help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
help="type of embedding quantization, '<bitwidth>,<groupsize>,<qge_dtype>', e.g., '8,1024,32'.",
)
parser.add_argument(
"--pt2e_quantize",
@@ -197,6 +199,12 @@ def build_args_parser() -> argparse.ArgumentParser:
action="store_true",
help="Whether to use sdpa_with_kv_cache update op when using kv cache",
)
parser.add_argument(
"--num_sharding",
type=int,
default=None,
help="Specify the number of splits which is generated with custom op. Expect to be able to divide num layer.",
)
parser.add_argument(
"--disable_dynamic_shape",
dest="enable_dynamic_shape",
@@ -385,7 +393,12 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
transforms.append(replace_sdpa_with_custom_op)

if args.use_kv_cache:
if args.qnn or args.coreml or args.mps:
if args.qnn:
transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)

elif args.coreml or args.mps:
# Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
# to get free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
@@ -486,18 +499,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
modelname = f"coreml_{modelname}"

if args.qnn:
from executorch.extension.llm.custom_ops import model_sharding

partitioners.append(
get_qnn_partitioner(
quant_dtype,
args.use_kv_cache,
args.pt2e_quantize,
args.use_kv_cache, args.pt2e_quantize, args.num_sharding
)
)
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import _transform

# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
_transform(builder_exported_to_edge.edge_manager.exported_program())
if args.num_sharding is not None:
model_sharding.split_graph(
builder_exported_to_edge.edge_manager.exported_program(),
builder_exported_to_edge.metadata["get_n_layers"],
shares=args.num_sharding,
)

if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
@@ -506,7 +525,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
logging.info("Generating etrecord")
# Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding is not None:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

canonicalize_program(builder.edge_manager.exported_program())
builder = builder.to_executorch()

# Generate ETRecord
if edge_manager_copy:
@@ -517,7 +541,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
)
logging.info("Generated etrecord.bin")
else:
builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding is not None:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

canonicalize_program(builder.edge_manager.exported_program())
builder = builder.to_executorch()

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
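Taken together, the export_llama_lib.py hunks add a QNN-specific path: split the graph into shards before partitioning, then canonicalize the sharded program after to_backend and before to_executorch. A condensed, non-authoritative sketch of that ordering, stitched from the hunks above:

```python
# Condensed sketch of the new QNN flow when --num_sharding is passed.
if args.qnn:
    from executorch.extension.llm.custom_ops import model_sharding

    if args.num_sharding is not None:
        # Insert the custom split op so the partitioner can cut the model into
        # `num_sharding` pieces (the layer count should be divisible by it).
        model_sharding.split_graph(
            builder_exported_to_edge.edge_manager.exported_program(),
            builder_exported_to_edge.metadata["get_n_layers"],
            shares=args.num_sharding,
        )

builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding is not None:
    from executorch.backends.qualcomm.utils.utils import canonicalize_program

    # Patch up the sharded, lowered program before serialization.
    canonicalize_program(builder.edge_manager.exported_program())
builder = builder.to_executorch()
```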
3 changes: 3 additions & 0 deletions examples/models/llama2/llama_transformer.py
@@ -161,6 +161,9 @@ def __init__(
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
self.transpose_cache = transpose_cache
self.enable_dynamic_shape = enable_dynamic_shape
self.register_buffer(
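Exposing max_batch_size, n_heads, and head_dim on the KV cache presumably lets the new replace_kv_cache_with_simple_kv_cache source transform rebuild an equivalent cache without re-deriving its geometry. A hypothetical sketch of how a replacement pass could use those attributes (SimpleKVCache-style constructor and helper names are placeholders, not the actual API):

```python
import torch.nn as nn

# Hypothetical sketch only: walk the model and swap each KVCache for a simpler
# cache built from the geometry attributes that the diff above exposes.
def swap_kv_caches(module: nn.Module, make_simple_cache) -> nn.Module:
    for name, child in module.named_children():
        if child.__class__.__name__ == "KVCache":
            setattr(
                module,
                name,
                make_simple_cache(child.max_batch_size, child.n_heads, child.head_dim),
            )
        else:
            swap_kv_caches(child, make_simple_cache)
    return module
```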
2 changes: 1 addition & 1 deletion examples/models/llama2/model.py
@@ -212,7 +212,7 @@ def get_example_inputs_kvcache_sdpa(self):
if self.enable_dynamic_shape:
return (
torch.tensor([[2, 3, 4]], dtype=torch.long),
torch.tensor([0], dtype=torch.long),
torch.tensor([0, 1, 2], dtype=torch.long),
)
else:
return (
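With dynamic shapes enabled, the example inputs now keep the token tensor and input_pos the same length, i.e. one position index per prompt token instead of a single starting index. For illustration:

```python
import torch

# Example inputs after the change: three prompt tokens, three position indices.
tokens = torch.tensor([[2, 3, 4]], dtype=torch.long)
input_pos = torch.tensor([0, 1, 2], dtype=torch.long)
assert tokens.shape[1] == input_pos.shape[0]
```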
22 changes: 18 additions & 4 deletions examples/models/llama2/source_transformation/quantize.py
@@ -384,7 +384,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


def replace_embedding_weight_only_grouped_int8_per_channel(
module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False, qge_dtype=torch.half
):
for name, child in module.named_children():
# print(f"name: {name}")
@@ -400,11 +400,12 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
embedding_dim=child.weight.shape[1],
group_size=group_size,
packed=packed,
dtype=qge_dtype,
),
)
else:
replace_embedding_weight_only_grouped_int8_per_channel(
child, device, bitwidth, group_size, packed
child, device, bitwidth, group_size, packed, qge_dtype
)


@@ -417,6 +418,7 @@ def __init__(
bitwidth: int = 8,
group_size: Optional[int] = None,
packed=False,
qge_dtype=torch.half,
):
if isinstance(packed, str):
packed = packed == "True"
@@ -425,6 +427,7 @@
self.group_size = group_size
self.bitwidth = bitwidth
self.packed = packed
self.qge_dtype = qge_dtype
if (bitwidth != 4) and packed:
raise RuntimeError("pack only works with bitsize 4")

@@ -484,7 +487,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:

def convert_for_runtime(self) -> nn.Module:
replace_embedding_weight_only_grouped_int8_per_channel(
self.mod, self.device, self.bitwidth, self.group_size, self.packed
self.mod, self.device, self.bitwidth, self.group_size, self.packed, self.qge_dtype
)
return self.mod

@@ -554,17 +557,28 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:


def get_quant_embedding_transform(args):
bitwidth, group_size = args.embedding_quantize.split(",")
quant_args = [a.strip() for a in args.embedding_quantize.split(",")]
bitwidth, group_size = quant_args[:2]
if group_size == "none" or group_size == "None" or group_size == "0":
group_size = None
else:
group_size = int(group_size)
bitwidth = int(bitwidth)

if len(quant_args) == 3:
qge_dtype = quant_args[2]
if qge_dtype in ("32", "torch.float32"):
qge_dtype = torch.float32
else:
print(f"Use default qge_dtype, {torch.half}")
qge_dtype = torch.half

return lambda model: EmbeddingQuantHandler(
model,
bitwidth=bitwidth,
group_size=group_size,
packed=(bitwidth == 4),
qge_dtype=qge_dtype,
).quantized_model()


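The --embedding-quantize flag now accepts an optional third field selecting the dtype of the quantized-embedding output (qge_dtype), defaulting to torch.half. A small standalone sketch of the extended parsing (the helper name is mine, but the logic mirrors get_quant_embedding_transform above):

```python
import torch

def parse_embedding_quantize(spec: str):
    """Parse '<bitwidth>,<groupsize>[,<qge_dtype>]', e.g. '4,32,32' or '8,1024'."""
    parts = [p.strip() for p in spec.split(",")]
    bitwidth = int(parts[0])
    group_size = None if parts[1] in ("none", "None", "0") else int(parts[1])
    qge_dtype = torch.half  # default, as in the diff
    if len(parts) == 3 and parts[2] in ("32", "torch.float32"):
        qge_dtype = torch.float32
    return bitwidth, group_size, qge_dtype

print(parse_embedding_quantize("4,32,32"))  # (4, 32, torch.float32)
print(parse_embedding_quantize("8,1024"))   # (8, 1024, torch.float16)
```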