@@ -62,7 +62,7 @@ def is_compiled_module(module: torch.nn.Module) -> bool:
6262
6363def has_compiled_regions (module : torch .nn .Module ) -> bool :
6464 """
65- Check whether the module has submodules that were compiled with torch.compile()
65+ Check whether the module has submodules that were compiled with `torch.compile()`.
6666 """
6767 if not hasattr (torch , "_dynamo" ):
6868 return False
@@ -75,6 +75,29 @@ def has_compiled_regions(module: torch.nn.Module) -> bool:
7575 return False
7676
7777
def is_repeated_blocks(module: torch.nn.Module) -> bool:
    """
    Check whether the module is a repeated block, i.e. a `torch.nn.ModuleList` whose children are all instances of
    the same class. This is useful to determine whether we should apply regional compilation to the module.

    Args:
        module (`torch.nn.Module`):
            The module to check.

    Returns:
        `bool`: `True` if `module` is a non-empty `ModuleList` whose children all share the class of the first
        child, `False` otherwise.
    """
    # Guard the empty case explicitly: `all()` over an empty iterable is vacuously True,
    # which would wrongly report an empty ModuleList as a set of repeated blocks.
    if not isinstance(module, torch.nn.ModuleList) or len(module) == 0:
        return False
    first_cls = type(module[0])
    return all(isinstance(child, first_cls) for child in module)
85+
86+
def has_repeated_blocks(module: torch.nn.Module) -> bool:
    """
    Check whether the module has repeated blocks, i.e. a `torch.nn.ModuleList` with all children of the same class,
    at any level of the module hierarchy. This is useful to determine whether we should apply regional compilation
    to the module.

    Args:
        module (`torch.nn.Module`):
            The module to check.

    Returns:
        `bool`: `True` if any module in the hierarchy (including `module` itself) is a repeated-blocks
        `ModuleList`, `False` otherwise.
    """
    # `module.modules()` yields `module` itself first and then every descendant, so no
    # separate guard on the private `_modules` attribute is needed: a leaf module yields
    # only itself, which `is_repeated_blocks` rejects.
    return any(is_repeated_blocks(submodule) for submodule in module.modules())
99+
100+
78101def compile_regions (module : torch .nn .Module , ** compile_kwargs ) -> torch .nn .Module :
79102 """
80103 Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to
@@ -123,33 +146,54 @@ def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Modul
123146 """
124147
125148 def _compile_regions (module : torch .nn .Module , ** compile_kwargs ) -> torch .nn .Module :
126- if isinstance (module , torch .nn .ModuleList ):
127- if all (isinstance (submodule , module [0 ].__class__ ) for submodule in module ):
128- new_module = torch .nn .ModuleList ()
129- for submodule in module :
130- new_module .append (torch .compile (submodule , ** compile_kwargs ))
131- else :
132- new_module = torch .compile (module , ** compile_kwargs )
133- elif module ._modules : # Non-leaf node
149+ if is_repeated_blocks (module ):
150+ new_module = torch .nn .ModuleList ()
151+ for submodule in module :
152+ new_module .append (torch .compile (submodule , ** compile_kwargs ))
153+ elif has_repeated_blocks (module ):
134154 new_module = module .__class__ .__new__ (module .__class__ )
135155 new_module .__dict__ .update (module .__dict__ )
136156 new_module ._modules = {}
137157 for name , submodule in module .named_children ():
138158 new_module .add_module (name , _compile_regions (submodule , ** compile_kwargs ))
139- else : # Leaf node
159+ else :
140160 new_module = torch .compile (module , ** compile_kwargs )
141161
142162 return new_module
143163
144164 new_module = _compile_regions (module , ** compile_kwargs )
145165
146- if not hasattr ( new_module , "_orig_mod" ) :
166+ if "_orig_mod" not in new_module . __dict__ :
147167 # Keeps a reference to the original module to decompile/unwrap it later
148168 new_module .__dict__ ["_orig_mod" ] = module
149169
150170 return new_module
151171
152172
def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
    """
    Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`.
    Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc. that
    `torch.compile(...)` interferes with, this version of regional compilation uses the in-place `module.compile()`
    method instead.

    Args:
        module (`torch.nn.Module`):
            The model to compile.
        **compile_kwargs:
            Additional keyword arguments to pass to `module.compile()`.
    """
    if is_repeated_blocks(module):
        # Compile each repeated block in place so cached artifacts are reused across blocks.
        for submodule in module:
            submodule.compile(**compile_kwargs)
    elif has_repeated_blocks(module):
        # Recurse until the repeated-blocks ModuleList is reached.
        for child in module.children():
            compile_regions_deepspeed(child, **compile_kwargs)
    else:
        # Leaf node (no repeated blocks anywhere below): compile it whole, in place.
        module.compile(**compile_kwargs)
195+
196+
153197def extract_model_from_parallel (
154198 model , keep_fp32_wrapper : bool = True , keep_torch_compile : bool = True , recursive : bool = False
155199):
@@ -175,9 +219,12 @@ def extract_model_from_parallel(
175219 is_compiled = is_compiled_module (model )
176220 has_compiled = has_compiled_regions (model )
177221
178- if is_compiled or has_compiled :
222+ if is_compiled :
179223 compiled_model = model
180224 model = model ._orig_mod
225+ elif has_compiled :
226+ compiled_model = model
227+ model = model .__dict__ ["_orig_mod" ]
181228
182229 if is_deepspeed_available ():
183230 from deepspeed import DeepSpeedEngine
@@ -221,9 +268,13 @@ def _recursive_unwrap(module):
221268 if getattr (model , "_converted_to_transformer_engine" , False ):
222269 convert_model (model , to_transformer_engine = False )
223270
224- if keep_torch_compile and (is_compiled or has_compiled ):
225- compiled_model ._orig_mod = model
226- model = compiled_model
271+ if keep_torch_compile :
272+ if is_compiled :
273+ compiled_model ._orig_mod = model
274+ model = compiled_model
275+ elif has_compiled :
276+ compiled_model .__dict__ ["_orig_mod" ] = model
277+ model = compiled_model
227278
228279 return model
229280
0 commit comments