Skip to content

Commit 13b49c3

Browse files
pytorchbotSS-JIA
authored andcommitted
[ET-VK] Cache compiled SPIR-V and only recompile when changed (#9701)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #9652 by @SS-JIA ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/202/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/202/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/202/orig @diff-train-skip-merge Co-authored-by: Stephen Jia <[email protected]>
1 parent d0e8dbf commit 13b49c3

File tree

3 files changed

+69
-8
lines changed

3 files changed

+69
-8
lines changed

backends/vulkan/cmake/ShaderLibrary.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function(gen_vulkan_shader_lib_cpp shaders_path)
5252
"${PYTHON_EXECUTABLE}"
5353
${EXECUTORCH_ROOT}/backends/vulkan/runtime/gen_vulkan_spv.py --glsl-path
5454
${shaders_path} --output-path ${VULKAN_SHADERGEN_OUT_PATH}
55-
--glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH}
55+
--glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH}/shader_cache/
5656
--env ${VULKAN_GEN_ARG_ENV}
5757
RESULT_VARIABLE error_code
5858
)

backends/vulkan/runtime/gen_vulkan_spv.py

+66-6
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
import codecs
1313
import copy
1414
import glob
15+
import hashlib
1516
import io
1617
import os
1718
import re
19+
import shutil
1820
import sys
1921
from itertools import product
2022
from multiprocessing.pool import ThreadPool
@@ -733,7 +735,29 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
733735
input_text = input_text.replace("uint16_t", "int")
734736
return input_text
735737

736-
def generateSPV(self, output_dir: str) -> Dict[str, str]:
738+
def get_md5_checksum(self, file_path: str) -> str:
739+
# Use a reasonably sized buffer for better performance with large files
740+
BUF_SIZE = 65536 # 64kb chunks
741+
742+
md5 = hashlib.md5()
743+
744+
with open(file_path, "rb") as f:
745+
while True:
746+
data = f.read(BUF_SIZE)
747+
if not data:
748+
break
749+
md5.update(data)
750+
751+
# Get the hexadecimal digest and compare
752+
file_md5 = md5.hexdigest()
753+
return file_md5
754+
755+
def generateSPV( # noqa: C901
756+
self,
757+
output_dir: str,
758+
cache_dir: Optional[str] = None,
759+
force_rebuild: bool = False,
760+
) -> Dict[str, str]:
737761
output_file_map = {}
738762

739763
def process_shader(shader_paths_pair):
@@ -742,20 +766,48 @@ def process_shader(shader_paths_pair):
742766
source_glsl = shader_paths_pair[1][0]
743767
shader_params = shader_paths_pair[1][1]
744768

769+
glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
770+
spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")
771+
772+
if cache_dir is not None:
773+
cached_source_glsl = os.path.join(
774+
cache_dir, os.path.basename(source_glsl) + ".t"
775+
)
776+
cached_glsl_out_path = os.path.join(cache_dir, f"{shader_name}.glsl")
777+
cached_spv_out_path = os.path.join(cache_dir, f"{shader_name}.spv")
778+
if (
779+
not force_rebuild
780+
and os.path.exists(cached_source_glsl)
781+
and os.path.exists(cached_glsl_out_path)
782+
and os.path.exists(cached_spv_out_path)
783+
):
784+
current_checksum = self.get_md5_checksum(source_glsl)
785+
cached_checksum = self.get_md5_checksum(cached_source_glsl)
786+
# If the cached source GLSL template is the same as the current GLSL
787+
# source file, then assume that the generated GLSL and SPIR-V will
788+
# not have changed. In that case, just copy over the GLSL and SPIR-V
789+
# files from the cache.
790+
if current_checksum == cached_checksum:
791+
shutil.copyfile(cached_spv_out_path, spv_out_path)
792+
shutil.copyfile(cached_glsl_out_path, glsl_out_path)
793+
return (spv_out_path, glsl_out_path)
794+
745795
with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
746796
input_text = input_file.read()
747797
input_text = self.maybe_replace_u16vecn(input_text)
748798
output_text = preprocess(input_text, shader_params)
749799

750-
glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
751800
with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file:
752801
output_file.write(output_text)
753802

803+
if cache_dir is not None:
804+
# Otherwise, store the generated and source GLSL files in the cache
805+
shutil.copyfile(source_glsl, cached_source_glsl)
806+
shutil.copyfile(glsl_out_path, cached_glsl_out_path)
807+
754808
# If no GLSL compiler is specified, then only write out the generated GLSL shaders.
755809
# This is mainly for testing purposes.
756810
if self.glslc_path is not None:
757-
spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")
758-
759811
cmd_base = [
760812
self.glslc_path,
761813
"-fshader-stage=compute",
@@ -788,6 +840,9 @@ def process_shader(shader_paths_pair):
788840
else:
789841
raise RuntimeError(f"{err_msg_base} {e.stderr}") from e
790842

843+
if cache_dir is not None:
844+
shutil.copyfile(spv_out_path, cached_spv_out_path)
845+
791846
return (spv_out_path, glsl_out_path)
792847

793848
# Parallelize shader compilation as much as possible to optimize build time.
@@ -1089,8 +1144,11 @@ def main(argv: List[str]) -> int:
10891144
default=["."],
10901145
)
10911146
parser.add_argument("-c", "--glslc-path", required=True, help="")
1092-
parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
1147+
parser.add_argument(
1148+
"-t", "--tmp-dir-path", required=True, help="/tmp/vulkan_shaders/"
1149+
)
10931150
parser.add_argument("-o", "--output-path", required=True, help="")
1151+
parser.add_argument("-f", "--force-rebuild", action="store_true", default=False)
10941152
parser.add_argument("--replace-u16vecn", action="store_true", default=False)
10951153
parser.add_argument("--optimize_size", action="store_true", help="")
10961154
parser.add_argument("--optimize", action="store_true", help="")
@@ -1131,7 +1189,9 @@ def main(argv: List[str]) -> int:
11311189
glslc_flags=glslc_flags_str,
11321190
replace_u16vecn=options.replace_u16vecn,
11331191
)
1134-
output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)
1192+
output_spv_files = shader_generator.generateSPV(
1193+
options.output_path, options.tmp_dir_path, options.force_rebuild
1194+
)
11351195

11361196
genCppFiles(
11371197
output_spv_files,

backends/vulkan/targets.bzl

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False, no_volk = Fal
4444
"--glsl-paths {} ".format(" ".join(glsl_paths)) +
4545
"--output-path $OUT " +
4646
"--glslc-path=$(exe {}) ".format(glslc_path) +
47-
"--tmp-dir-path=$OUT " +
47+
"--tmp-dir-path=shader_cache " +
48+
("-f " if read_config("etvk", "force_shader_rebuild", "0") == "1" else " ") +
4849
select({
4950
"DEFAULT": "",
5051
"ovr_config//os:android": "--optimize",

0 commit comments

Comments
 (0)