From d6063fc10012061724e148cd26bbf6256fb9ace3 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 26 Mar 2025 13:56:55 -0700 Subject: [PATCH] [ET-VK] Cache compiled SPIR-V and only recompile when changed Pull Request resolved: https://github.com/pytorch/executorch/pull/9652 ## Context As title. Introduce a caching mechanism to store SPIR-V generated in a previous build to avoid re-compiling them. A shader will not be re-compiled if the source GLSL template is unchanged. For Meta internal builds, the `etvk.force_shader_rebuild` buck config can be used to force all shaders to be built without using the cache. ghstack-source-id: 274281888 @exported-using-ghexport Differential Revision: [D71916745](https://our.internmc.facebook.com/intern/diff/D71916745/) --- backends/vulkan/cmake/ShaderLibrary.cmake | 2 +- backends/vulkan/runtime/gen_vulkan_spv.py | 72 +++++++++++++++++++++-- backends/vulkan/targets.bzl | 3 +- 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/cmake/ShaderLibrary.cmake b/backends/vulkan/cmake/ShaderLibrary.cmake index bbf81e7bcba..67285738b4c 100644 --- a/backends/vulkan/cmake/ShaderLibrary.cmake +++ b/backends/vulkan/cmake/ShaderLibrary.cmake @@ -52,7 +52,7 @@ function(gen_vulkan_shader_lib_cpp shaders_path) "${PYTHON_EXECUTABLE}" ${EXECUTORCH_ROOT}/backends/vulkan/runtime/gen_vulkan_spv.py --glsl-path ${shaders_path} --output-path ${VULKAN_SHADERGEN_OUT_PATH} - --glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH} + --glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH}/shader_cache/ --env ${VULKAN_GEN_ARG_ENV} RESULT_VARIABLE error_code ) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index a3d214f5ae8..594e5f6ad44 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -12,9 +12,11 @@ import codecs import copy import glob +import hashlib import io import os import re +import shutil import sys from itertools import product from multiprocessing.pool import ThreadPool @@ -733,7 +735,29 @@ def maybe_replace_u16vecn(self, input_text: str) -> str: input_text = input_text.replace("uint16_t", "int") return input_text - def generateSPV(self, output_dir: str) -> Dict[str, str]: + def get_md5_checksum(self, file_path: str) -> str: + # Use a reasonably sized buffer for better performance with large files + BUF_SIZE = 65536 # 64kb chunks + + md5 = hashlib.md5() + + with open(file_path, "rb") as f: + while True: + data = f.read(BUF_SIZE) + if not data: + break + md5.update(data) + + # Get the hexadecimal digest and compare + file_md5 = md5.hexdigest() + return file_md5 + + def generateSPV( # noqa: C901 + self, + output_dir: str, + cache_dir: Optional[str] = None, + force_rebuild: bool = False, + ) -> Dict[str, str]: output_file_map = {} def process_shader(shader_paths_pair): @@ -742,20 +766,48 @@ def process_shader(shader_paths_pair): source_glsl = shader_paths_pair[1][0] shader_params = shader_paths_pair[1][1] + glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl") + spv_out_path = os.path.join(output_dir, f"{shader_name}.spv") + + if cache_dir is not None: + cached_source_glsl = os.path.join( + cache_dir, os.path.basename(source_glsl) + ".t" + ) + cached_glsl_out_path = os.path.join(cache_dir, f"{shader_name}.glsl") + cached_spv_out_path = os.path.join(cache_dir, f"{shader_name}.spv") + if ( + not force_rebuild + and os.path.exists(cached_source_glsl) + and os.path.exists(cached_glsl_out_path) + and os.path.exists(cached_spv_out_path) + ): + current_checksum = self.get_md5_checksum(source_glsl) + cached_checksum = self.get_md5_checksum(cached_source_glsl) + # If the cached source GLSL template is the same as the current GLSL + # source file, then assume that the generated GLSL and SPIR-V will + # not have changed. In that case, just copy over the GLSL and SPIR-V + # files from the cache. + if current_checksum == cached_checksum: + shutil.copyfile(cached_spv_out_path, spv_out_path) + shutil.copyfile(cached_glsl_out_path, glsl_out_path) + return (spv_out_path, glsl_out_path) + with codecs.open(source_glsl, "r", encoding="utf-8") as input_file: input_text = input_file.read() input_text = self.maybe_replace_u16vecn(input_text) output_text = preprocess(input_text, shader_params) - glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl") with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file: output_file.write(output_text) + if cache_dir is not None: + # Otherwise, store the generated and source GLSL files in the cache + shutil.copyfile(source_glsl, cached_source_glsl) + shutil.copyfile(glsl_out_path, cached_glsl_out_path) + # If no GLSL compiler is specified, then only write out the generated GLSL shaders. # This is mainly for testing purposes. if self.glslc_path is not None: - spv_out_path = os.path.join(output_dir, f"{shader_name}.spv") - cmd_base = [ self.glslc_path, "-fshader-stage=compute", @@ -788,6 +840,9 @@ def process_shader(shader_paths_pair): else: raise RuntimeError(f"{err_msg_base} {e.stderr}") from e + if cache_dir is not None: + shutil.copyfile(spv_out_path, cached_spv_out_path) + return (spv_out_path, glsl_out_path) # Parallelize shader compilation as much as possible to optimize build time. @@ -1089,8 +1144,11 @@ def main(argv: List[str]) -> int: default=["."], ) parser.add_argument("-c", "--glslc-path", required=True, help="") - parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp") + parser.add_argument( + "-t", "--tmp-dir-path", required=True, help="/tmp/vulkan_shaders/" + ) parser.add_argument("-o", "--output-path", required=True, help="") + parser.add_argument("-f", "--force-rebuild", action="store_true", default=False) parser.add_argument("--replace-u16vecn", action="store_true", default=False) parser.add_argument("--optimize_size", action="store_true", help="") parser.add_argument("--optimize", action="store_true", help="") @@ -1131,7 +1189,9 @@ def main(argv: List[str]) -> int: glslc_flags=glslc_flags_str, replace_u16vecn=options.replace_u16vecn, ) - output_spv_files = shader_generator.generateSPV(options.tmp_dir_path) + output_spv_files = shader_generator.generateSPV( + options.output_path, options.tmp_dir_path, options.force_rebuild + ) genCppFiles( output_spv_files, diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index d2314d138bf..aafc87ad2c3 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -44,7 +44,8 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False, no_volk = Fal "--glsl-paths {} ".format(" ".join(glsl_paths)) + "--output-path $OUT " + "--glslc-path=$(exe {}) ".format(glslc_path) + - "--tmp-dir-path=$OUT " + + "--tmp-dir-path=shader_cache " + + ("-f " if read_config("etvk", "force_shader_rebuild", "0") == "1" else " ") + select({ "DEFAULT": "", "ovr_config//os:android": "--optimize",