12
12
import codecs
13
13
import copy
14
14
import glob
15
+ import hashlib
15
16
import io
16
17
import os
17
18
import re
19
+ import shutil
18
20
import sys
19
21
from itertools import product
20
22
from multiprocessing .pool import ThreadPool
@@ -733,7 +735,29 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
733
735
input_text = input_text .replace ("uint16_t" , "int" )
734
736
return input_text
735
737
736
- def generateSPV (self , output_dir : str ) -> Dict [str , str ]:
738
+ def get_md5_checksum (self , file_path : str ) -> str :
739
+ # Use a reasonably sized buffer for better performance with large files
740
+ BUF_SIZE = 65536 # 64kb chunks
741
+
742
+ md5 = hashlib .md5 ()
743
+
744
+ with open (file_path , "rb" ) as f :
745
+ while True :
746
+ data = f .read (BUF_SIZE )
747
+ if not data :
748
+ break
749
+ md5 .update (data )
750
+
751
+ # Get the hexadecimal digest and compare
752
+ file_md5 = md5 .hexdigest ()
753
+ return file_md5
754
+
755
+ def generateSPV ( # noqa: C901
756
+ self ,
757
+ output_dir : str ,
758
+ cache_dir : Optional [str ] = None ,
759
+ force_rebuild : bool = False ,
760
+ ) -> Dict [str , str ]:
737
761
output_file_map = {}
738
762
739
763
def process_shader (shader_paths_pair ):
@@ -742,20 +766,48 @@ def process_shader(shader_paths_pair):
742
766
source_glsl = shader_paths_pair [1 ][0 ]
743
767
shader_params = shader_paths_pair [1 ][1 ]
744
768
769
+ glsl_out_path = os .path .join (output_dir , f"{ shader_name } .glsl" )
770
+ spv_out_path = os .path .join (output_dir , f"{ shader_name } .spv" )
771
+
772
+ if cache_dir is not None :
773
+ cached_source_glsl = os .path .join (
774
+ cache_dir , os .path .basename (source_glsl ) + ".t"
775
+ )
776
+ cached_glsl_out_path = os .path .join (cache_dir , f"{ shader_name } .glsl" )
777
+ cached_spv_out_path = os .path .join (cache_dir , f"{ shader_name } .spv" )
778
+ if (
779
+ not force_rebuild
780
+ and os .path .exists (cached_source_glsl )
781
+ and os .path .exists (cached_glsl_out_path )
782
+ and os .path .exists (cached_spv_out_path )
783
+ ):
784
+ current_checksum = self .get_md5_checksum (source_glsl )
785
+ cached_checksum = self .get_md5_checksum (cached_source_glsl )
786
+ # If the cached source GLSL template is the same as the current GLSL
787
+ # source file, then assume that the generated GLSL and SPIR-V will
788
+ # not have changed. In that case, just copy over the GLSL and SPIR-V
789
+ # files from the cache.
790
+ if current_checksum == cached_checksum :
791
+ shutil .copyfile (cached_spv_out_path , spv_out_path )
792
+ shutil .copyfile (cached_glsl_out_path , glsl_out_path )
793
+ return (spv_out_path , glsl_out_path )
794
+
745
795
with codecs .open (source_glsl , "r" , encoding = "utf-8" ) as input_file :
746
796
input_text = input_file .read ()
747
797
input_text = self .maybe_replace_u16vecn (input_text )
748
798
output_text = preprocess (input_text , shader_params )
749
799
750
- glsl_out_path = os .path .join (output_dir , f"{ shader_name } .glsl" )
751
800
with codecs .open (glsl_out_path , "w" , encoding = "utf-8" ) as output_file :
752
801
output_file .write (output_text )
753
802
803
+ if cache_dir is not None :
804
+ # Otherwise, store the generated and source GLSL files in the cache
805
+ shutil .copyfile (source_glsl , cached_source_glsl )
806
+ shutil .copyfile (glsl_out_path , cached_glsl_out_path )
807
+
754
808
# If no GLSL compiler is specified, then only write out the generated GLSL shaders.
755
809
# This is mainly for testing purposes.
756
810
if self .glslc_path is not None :
757
- spv_out_path = os .path .join (output_dir , f"{ shader_name } .spv" )
758
-
759
811
cmd_base = [
760
812
self .glslc_path ,
761
813
"-fshader-stage=compute" ,
@@ -788,6 +840,9 @@ def process_shader(shader_paths_pair):
788
840
else :
789
841
raise RuntimeError (f"{ err_msg_base } { e .stderr } " ) from e
790
842
843
+ if cache_dir is not None :
844
+ shutil .copyfile (spv_out_path , cached_spv_out_path )
845
+
791
846
return (spv_out_path , glsl_out_path )
792
847
793
848
# Parallelize shader compilation as much as possible to optimize build time.
@@ -1089,8 +1144,11 @@ def main(argv: List[str]) -> int:
1089
1144
default = ["." ],
1090
1145
)
1091
1146
parser .add_argument ("-c" , "--glslc-path" , required = True , help = "" )
1092
- parser .add_argument ("-t" , "--tmp-dir-path" , required = True , help = "/tmp" )
1147
+ parser .add_argument (
1148
+ "-t" , "--tmp-dir-path" , required = True , help = "/tmp/vulkan_shaders/"
1149
+ )
1093
1150
parser .add_argument ("-o" , "--output-path" , required = True , help = "" )
1151
+ parser .add_argument ("-f" , "--force-rebuild" , action = "store_true" , default = False )
1094
1152
parser .add_argument ("--replace-u16vecn" , action = "store_true" , default = False )
1095
1153
parser .add_argument ("--optimize_size" , action = "store_true" , help = "" )
1096
1154
parser .add_argument ("--optimize" , action = "store_true" , help = "" )
@@ -1131,7 +1189,9 @@ def main(argv: List[str]) -> int:
1131
1189
glslc_flags = glslc_flags_str ,
1132
1190
replace_u16vecn = options .replace_u16vecn ,
1133
1191
)
1134
- output_spv_files = shader_generator .generateSPV (options .tmp_dir_path )
1192
+ output_spv_files = shader_generator .generateSPV (
1193
+ options .output_path , options .tmp_dir_path , options .force_rebuild
1194
+ )
1135
1195
1136
1196
genCppFiles (
1137
1197
output_spv_files ,
0 commit comments