Skip to content

Commit 1d81cd0

Browse files
committed
[ET-VK] Enable Partial GPU lowering via Vulkan in stories model export
## Context

Simple change to add the Vulkan Partitioner as a dependency for the llama exporter and runner, and to provide a command-line flag that invokes the Vulkan partitioner during export. Also includes a small change to the Vulkan serializer that was needed for everything to work (i.e. enabling serialization of multiple graph outputs).

Differential Revision: [D54805831](https://our.internmc.facebook.com/intern/diff/D54805831/)
ghstack-source-id: 218336771
Pull Request resolved: #2368
1 parent 9bc9d81 commit 1d81cd0

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,14 @@ def process_getattr_node(self, node: Node) -> None:
218218
self.create_tensor_values(node)
219219

220220
def process_output_node(self, node: Node) -> None:
221-
if node.all_input_nodes[0] not in self.node_to_value_ids:
222-
raise AssertionError(
223-
"Cannot find input to output node in node_to_value_ids. This means the "
224-
"output node is being serialized before its corresponding internal node "
225-
"which is not allowed."
226-
)
227-
self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]])
221+
for out_node in node.all_input_nodes:
222+
if out_node not in self.node_to_value_ids:
223+
raise AssertionError(
224+
"Cannot find input to output node in node_to_value_ids. This means "
225+
"the output node is being serialized before its corresponding "
226+
"internal node which is not allowed."
227+
)
228+
self.output_ids.append(self.node_to_value_ids[out_node])
228229

229230
def process_node(self, node: Node) -> None:
230231
if node.op == "placeholder":

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ runtime.python_library(
8282
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
8383
"//executorch/backends/xnnpack:xnnpack_backend",
8484
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
85+
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
8586
"//executorch/examples/models:model_base",
8687
"//executorch/examples/models:models",
8788
"//executorch/examples/portable:utils",

examples/models/llama2/export_llama_lib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pkg_resources
1919
import torch
20+
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
2021
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
2122
XnnpackDynamicallyQuantizedPartitioner,
2223
)
@@ -356,6 +357,7 @@ def build_args_parser() -> argparse.ArgumentParser:
356357
parser.add_argument("-2", "--fairseq2", action="store_true")
357358
parser.add_argument("-v", "--verbose", action="store_true")
358359
parser.add_argument("-X", "--xnnpack", action="store_true")
360+
parser.add_argument("-V", "--vulkan", action="store_true")
359361

360362
return parser
361363

@@ -451,6 +453,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901
451453
)
452454
# partitioners[XnnpackPartitioner.__name__] = XnnpackPartitioner()
453455
modelname = f"xnnpack_{modelname}"
456+
if args.vulkan:
457+
partitioners[VulkanPartitioner.__name__] = VulkanPartitioner()
458+
modelname = f"vulkan_{modelname}"
454459

455460
builder = (
456461
load_llama_model(

examples/models/llama2/runner/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def define_common_targets():
2929
],
3030
exported_deps = [
3131
"//executorch/backends/xnnpack:xnnpack_backend",
32+
"//executorch/backends/vulkan:vulkan_backend_lib",
3233
"//executorch/examples/models/llama2/sampler:sampler" + aten_suffix,
3334
"//executorch/examples/models/llama2/tokenizer:tokenizer",
3435
"//executorch/extension/evalue_util:print_evalue" + aten_suffix,

0 commit comments

Comments (0)