Commit 5bb0389

Update on "[ET-VK] Introduce graph runtime shader library that enables dynamic shapes"

## Context

pytorch/pytorch#121598 introduces the ability to support dynamic shapes through tensor metadata updates. The idea is fairly simple. Instead of shaders accepting a single UBO with size data for all arguments:

```
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 output_sizes;
  ivec4 other_sizes;
  float alpha;
}
```

shaders will accept separate UBOs for each piece of tensor metadata:

```
layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes {
  ivec4 data;
}
out_sizes;

layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
  ivec4 data;
}
in_sizes;

layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
  ivec4 data;
}
other_sizes;

layout(set = 0, binding = 6) uniform PRECISION restrict Alpha {
  float data;
}
alpha;
```

Each UBO will be owned and maintained by the corresponding `vTensor` instance. To support resizing a graph input, every tensor in the graph only needs to update its metadata UBOs via the `tensor.virtual_resize(new_sizes)` call. Shader dispatches in subsequent command buffer submissions will then see the updated metadata and execute as if the tensors had the new sizes.

This changeset introduces a new shader library for the Vulkan graph runtime that enables dynamic shapes through this technique, rather than relying on the shader library from PyTorch Vulkan.

## Considerations

Technically, the UBO update technique could be applied to the shaders from PyTorch Vulkan as well. If so, why introduce a new shader library for the graph runtime?

The primary motivation is code quality. First, having `vTensor` supply UBOs for its own metadata greatly reduces the need for operator-specific ad-hoc `Params` structs that organize arguments to be written into an `api::UniformParamsBuffer`. Constructing an `ExecuteNode` for binary operators is now

```
graph.execute_nodes().emplace_back(new ExecuteNode(
    graph,
    api::shader_registry().get_shader_info(kernel_name.str()),
    global_size,
    local_size,
    {{out, api::MemoryAccessType::WRITE},
     {{arg1, arg2}, api::MemoryAccessType::READ}},
    {t_out.gpu_sizes_ubo(),
     t_in1.gpu_sizes_ubo(),
     t_in2.gpu_sizes_ubo(),
     graph.create_params_buffer(alpha_val)}))
```

instead of

```
ArithmeticParams block{
    get_size_as_ivec4(t_out),
    get_size_as_ivec4(t_in1),
    get_size_as_ivec4(t_in2),
    alpha_val,
};
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
    graph,
    shader,
    global_size,
    local_size,
    {{out, api::MemoryAccessType::WRITE},
     {{arg1, arg2}, api::MemoryAccessType::READ}},
    std::move(params)));
```

Another consideration is that pytorch/pytorch#115948, which landed fairly recently, enables much more expressive shader templates through the use of Python code blocks in GLSL templates. This makes it easy to express shader variants for different data types, packing structures, and so on. Introducing a new shader library provides the opportunity to rewrite the shaders from PyTorch Vulkan in a more generic and extensible way.

Differential Revision: [D54754545](https://our.internmc.facebook.com/intern/diff/D54754545/)

[ghstack-poisoned]
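To make the resize flow described in the commit message concrete, here is a minimal host-side sketch. Only `virtual_resize()` is described by this change; the `compute_out_sizes` helper, the `submit_and_wait` call, and the concrete sizes are hypothetical placeholders used purely for illustration.

```
// Illustrative sketch only, not part of this changeset. Apart from
// virtual_resize(), the helper names below are assumptions.

// New sizes arrive for a graph input between two inferences.
std::vector<int64_t> new_in_sizes = {1, 64, 256};

// Each vTensor rewrites the contents of the sizes UBO it owns; no pipelines,
// descriptor sets, or command buffers need to be rebuilt for this.
t_in.virtual_resize(new_in_sizes);
t_out.virtual_resize(compute_out_sizes(new_in_sizes)); // hypothetical helper

// Re-submitting the previously recorded command buffer dispatches the same
// shaders; they read the refreshed out_sizes.data / in_sizes.data values and
// execute as if the tensors had always had the new sizes.
submit_and_wait(graph); // hypothetical submission call
```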
2 parents: 9be2c32 + 1a156e5

13 files changed, +42 -869 lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
-define_common_targets()
+define_common_targets(is_fbcode = True)
 
 runtime.python_library(
     name = "vulkan_preprocess",

backends/vulkan/targets.bzl

Lines changed: 17 additions & 15 deletions
@@ -1,23 +1,22 @@
-load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_fbcode")
-load("@fbsource//tools/build_defs:glob_defs.bzl", "subdir_glob")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
-def get_glsl_image_format():
-    if native.read_config("pt", "vulkan_full_precision", "0") == "0":
-        return "rgba16f"
-    return "rgba32f"
-
-def vulkan_spv_shader_lib(name, spv_filegroup):
+def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False):
     gen_aten_vulkan_spv_target = "//caffe2/tools:gen_aten_vulkan_spv_bin"
     glslc_path = "//caffe2/fb/vulkan/dotslash:glslc"
-    if is_fbcode():
+    if is_fbcode:
         gen_aten_vulkan_spv_target = "//caffe2:gen_vulkan_spv_bin"
         glslc_path = "//caffe2/fb/vulkan/tools:glslc"
 
+    glsl_paths = []
+
+    # TODO(ssjia): remove the need for subpath once subdir_glob is enabled in OSS
+    for target, subpath in spv_filegroups.items():
+        glsl_paths.append("$(location {})/{}".format(target, subpath))
+
     genrule_cmd = [
         "$(exe {})".format(gen_aten_vulkan_spv_target),
-        "--glsl-paths $(location {})".format(spv_filegroup),
-        "--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
+        "--glsl-paths {}".format(" ".join(glsl_paths)),
+        "--output-path $OUT",
         "--glslc-path=$(exe {})".format(glslc_path),
         "--tmp-dir-path=$OUT",
     ]
@@ -49,7 +48,7 @@ def vulkan_spv_shader_lib(name, spv_filegroup):
         ],
     )
 
-def define_common_targets():
+def define_common_targets(is_fbcode = False):
     runtime.genrule(
         name = "gen_vk_delegate_schema",
         srcs = [
@@ -89,14 +88,17 @@ def define_common_targets():
 
     runtime.filegroup(
         name = "vulkan_graph_runtime_shaders",
-        srcs = subdir_glob([
-            ("runtime/graph/ops/glsl", "*"),
+        srcs = native.glob([
+            "runtime/graph/ops/glsl/*",
         ]),
     )
 
     vulkan_spv_shader_lib(
         name = "vulkan_graph_runtime_shaderlib",
-        spv_filegroup = ":vulkan_graph_runtime_shaders",
+        spv_filegroups = {
+            ":vulkan_graph_runtime_shaders": "runtime/graph/ops/glsl",
+        },
+        is_fbcode = is_fbcode,
     )
 
     runtime.cxx_library(

examples/models/llama2/quantize.py

Lines changed: 12 additions & 6 deletions
@@ -916,8 +916,9 @@ def linear_forward_8da4w(
     x, weight_int8, scales, zeros, out_features, group_size, precision
 ):
     x = per_token_dynamic_quant(x)
-    origin_x_size = x.size()
-    x = x.reshape(-1, origin_x_size[-1])
+    # TODO: verify and remove following reshape code
+    # origin_x_size = x.size()
+    # x = x.reshape(-1, origin_x_size[-1])
 
     # TODO: better API
     # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
@@ -939,8 +940,8 @@ def linear_forward_8da4w(
     # w_dq = w_dq.to(torch.float16)
     c = torch.nn.functional.linear(x, w_dq)
 
-    new_shape = origin_x_size[:-1] + (out_features,)
-    c = c.reshape(new_shape)
+    # new_shape = origin_x_size[:-1] + (out_features,)
+    # c = c.reshape(new_shape)
 
     return c
 
@@ -1144,7 +1145,8 @@ def __init__(
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(self.precision)
-        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        # padding is removed for perf
+        # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_8da4w(
             input,
             self.weight,
@@ -1387,6 +1389,10 @@ def make_names_and_values_dict_func(q, qparams):
 
     def convert_for_runtime(self, model):
         replace_linear_8da4w(
-            model, self.groupsize, self.inner_k_tiles, self.padding_allowed
+            model,
+            self.groupsize,
+            self.padding_allowed,
+            torch.int8,
+            self.precision,
         )
         return model

examples/portable/scripts/export.py

Lines changed: 4 additions & 1 deletion
@@ -9,6 +9,8 @@
 import argparse
 import logging
 
+import torch
+
 from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
 
 from ...models import MODEL_NAME_TO_MODEL
@@ -75,4 +77,5 @@ def main() -> None:
 
 
 if __name__ == "__main__":
-    main()  # pragma: no cover
+    with torch.no_grad():
+        main()  # pragma: no cover

examples/sdk/scripts/export_bundled_program.py

Lines changed: 3 additions & 11 deletions
@@ -8,15 +8,11 @@
 
 import argparse
 
-from typing import List, Union
+from typing import List
 
 import torch
 
-from executorch.exir import (
-    ExecutorchProgram,
-    ExecutorchProgramManager,
-    MultiMethodExecutorchProgram,
-)
+from executorch.exir import ExecutorchProgramManager
 from executorch.sdk import BundledProgram
 from executorch.sdk.bundled_program.config import (
     MethodInputType,
@@ -33,11 +29,7 @@
 
 
 def save_bundled_program(
-    executorch_program: Union[
-        ExecutorchProgram,
-        MultiMethodExecutorchProgram,
-        ExecutorchProgramManager,
-    ],
+    executorch_program: ExecutorchProgramManager,
     method_test_suites: List[MethodTestSuite],
     output_path: str,
 ):

exir/__init__.py

Lines changed: 0 additions & 8 deletions
@@ -10,7 +10,6 @@
     _capture_legacy_do_not_use,
     CallSpec,
     capture,
-    capture_multiple,
     CaptureConfig,
     EdgeCompileConfig,
     ExecutorchBackendConfig,
@@ -23,9 +22,6 @@
     ExecutorchProgram,
     ExecutorchProgramManager,
     ExirExportedProgram,
-    multi_method_program_to_executorch,
-    MultiMethodExecutorchProgram,
-    MultiMethodExirExportedProgram,
     to_edge,
 )
 from executorch.exir.tracer import ExirDynamoConfig
@@ -37,7 +33,6 @@
     "emit_program",
     "EmitterOutput",
     "capture",
-    "capture_multiple",
     "_capture_legacy_do_not_use",
     "CallSpec",
     "ExportedProgram",
@@ -49,12 +44,9 @@
     "EdgeProgramManager",
     "ExecutorchProgramManager",
    "edge_to_executorch_passes",
-    "MultiMethodExirExportedProgram",
-    "MultiMethodExecutorchProgram",
     "CaptureConfig",
     "EdgeCompileConfig",
     "ExecutorchBackendConfig",
     "Value",
-    "multi_method_program_to_executorch",
     "ExirDynamoConfig",
 ]

exir/capture/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -10,7 +10,6 @@
     _capture_legacy_do_not_use,
     CallSpec,
     capture,
-    capture_multiple,
 )
 
 from executorch.exir.capture._config import (
@@ -23,7 +22,6 @@
     "CallSpec",
     "capture",
     "_capture_legacy_do_not_use",
-    "capture_multiple",
     "CaptureConfig",
     "EdgeCompileConfig",
     "ExecutorchBackendConfig",

exir/capture/_capture.py

Lines changed: 2 additions & 133 deletions
@@ -9,12 +9,12 @@
 from collections import namedtuple
 from contextlib import contextmanager
 from types import MethodType
-from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, cast, List, Optional, Tuple
 
 import torch
 from executorch.exir.capture._config import CaptureConfig
 from executorch.exir.error import ExportError, ExportErrorType, InternalError
-from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
+from executorch.exir.program import ExirExportedProgram
 from executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE
 from executorch.exir.tracer import (
     _default_decomposition_table,
@@ -360,137 +360,6 @@ def convert_to_fake(x):
     return ExirExportedProgram(ep, False)
 
 
-@compatibility(is_backward_compatible=False)
-def capture_multiple(
-    m: Union[torch.nn.Module, Callable[..., Any]],
-    args: Union[Dict[str, Tuple[Value, ...]], Tuple[Value, ...]],
-    config: Optional[CaptureConfig] = None,
-    prim_getters: Optional[Set[str]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], List[Any]]] = None,
-):
-    """
-    capture_multiple traces either an nn.Module or just a callable with PyTorch
-    operations inside and produce a single MultiMethodExirExportedProgram that
-    can potentially have multiple entry points. When multiple entry points
-    are traced, each of them is stored separately in the resulting
-    MultiMethodExirExportedProgram without sharing state.
-
-    Args:
-        m: the `nn.Module` or callable to trace.
-
-        args: Tracing example inputs.
-
-        When `m` is an nn.Module, `args` can be
-        1) A dictionary that maps names of method to their tracing example inputs.
-        in this case, all specified methods will be captured.
-        2) A tuple. In this case, `forward` method of `m` will be captured. It is
-        equivalent to passing {"forward", tuple-type-args}
-
-        When `m` is a non-Module callable, `args` must be a Tuple containing
-        tracing example inputs.
-
-        config: A CaptureConfig object that specifies how to interpret the
-        program being captured.
-
-        prim_getters: A set of primitive getter functions to capture the return values of
-
-        dynamic_shapes: Input dynamic shapes.
-
-        When `m` is an nn.Module, `dynamic_shapes` is a dictionary that maps names of method
-        to their input dynamic shapes.
-
-        When `m` is a non-Module callable, `dynamic_shapes` is a list of input dynamic shapes.
-
-    Returns:
-        A MultiMethodExirExportedProgram.
-
-        if `m` is an nn.Module, returned program would have multiple
-        captured methods, each corresponding to one entry in args dictionary.
-
-        if `m` is a non-Module callable, returned program would have a single
-        captured method named `forward`.
-
-    Raises:
-        AssertionError if given method name do not reference a valid method
-        on the given nn.Module.
-    """
-    warnings.warn(
-        "This function is now deprecated, please use `torch.export and exir.to_edge` instead.",
-        DeprecationWarning,
-        stacklevel=1,
-    )
-    # Normalize m and args.
-    compile_specs = []
-    prim_getter_cache: Optional[Dict[str, Any]] = None
-    if isinstance(m, torch.nn.Module):
-        if dynamic_shapes is not None:
-            assert isinstance(
-                dynamic_shapes, dict
-            ), f"Expected a dict for dynamic_shapes, got {type(dynamic_shapes)}"
-
-        if isinstance(args, tuple):
-            compile_specs.append(
-                CompileSpec(
-                    "forward",
-                    m.forward,
-                    args,
-                    (
-                        dynamic_shapes["forward"]
-                        if dynamic_shapes and "forward" in dynamic_shapes
-                        else None
-                    ),
-                )
-            )
-        else:
-            assert isinstance(
-                args, dict
-            ), f"Expected a tuple or Dict[str, tuple], got {type(args)}"
-            for method_name, method_args in args.items():
-                compile_specs.append(
-                    CompileSpec(
-                        method_name,
-                        getattr(m, method_name),
-                        method_args,
-                        (
-                            dynamic_shapes[method_name]
-                            if dynamic_shapes and method_name in dynamic_shapes
-                            else None
-                        ),
-                    )
-                )
-        if prim_getters is not None:
-            prim_getter_cache = {}
-            for getter in prim_getters:
-                prim_getter_cache[getter] = getattr(m, getter)()
-    else:
-        # Reaching here means `m` is a non-Module callable.
-        assert isinstance(
-            m, Callable
-        ), f"Only nn.Module or callable allowed, got {type(m)}"
-        assert isinstance(
-            args, tuple
-        ), f"When tracing a non-Module callable, `args` must be a tuple of tracing inputs, but got {type(args)}"
-        assert (
-            prim_getters is None
-        ), "Caller should not specify primitive getter functions when only providing a callable as input"
-        if dynamic_shapes is not None:
-            assert isinstance(
-                dynamic_shapes, list
-            ), f"Expected a list for constraints, got {type(dynamic_shapes)}"
-        compile_specs.append(CompileSpec("forward", m, args, dynamic_shapes))
-
-    method_name_to_prog = {}
-    for compile_spec in compile_specs:
-        method_name_to_prog[compile_spec.method_name] = capture(
-            compile_spec.callable,
-            compile_spec.args,
-            config,
-            compile_spec.dynamic_shapes,
-        )
-
-    return MultiMethodExirExportedProgram(method_name_to_prog, prim_getter_cache)
-
-
 
 # This is to bootstrap the missing meta["val"] when 1. ph consists of scalar
 # 2. meta["val"] is not properly set in dispatch_trace.
 def _instantiate_missing_placeholder_val_with_real_inputs(gm, args):

exir/program/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -13,9 +13,6 @@
     ExecutorchProgram,
     ExecutorchProgramManager,
     ExirExportedProgram,
-    multi_method_program_to_executorch,
-    MultiMethodExecutorchProgram,
-    MultiMethodExirExportedProgram,
     to_edge,
 )
 
@@ -25,9 +22,6 @@
     "_to_edge",
     "to_edge",
     "edge_to_executorch_passes",
-    "MultiMethodExirExportedProgram",
-    "MultiMethodExecutorchProgram",
-    "multi_method_program_to_executorch",
     "EdgeProgramManager",
     "ExecutorchProgramManager",
 ]
