Skip to content

Commit 6b61a37

Browse files
fix deepspeed regional compilation (#3609)
1 parent 682691d commit 6b61a37

File tree

4 files changed

+77
-22
lines changed

4 files changed

+77
-22
lines changed

src/accelerate/accelerator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@
125125
PROFILE_PATTERN_NAME,
126126
)
127127
from .utils.modeling import get_state_dict_offloaded_model
128-
from .utils.other import compile_regions, is_compiled_module
128+
from .utils.other import compile_regions, compile_regions_deepspeed, is_compiled_module
129129

130130

131131
if is_deepspeed_available():
@@ -2030,7 +2030,7 @@ def _prepare_deepspeed(self, *args):
20302030
if compare_versions("deepspeed", ">=", "0.14.4") and self.state.dynamo_plugin.backend != DynamoBackend.NO:
20312031
compile_kwargs = self.state.dynamo_plugin.to_kwargs()
20322032
if self.state.dynamo_plugin.use_regional_compilation:
2033-
engine.module = compile_regions(engine.module, **compile_kwargs)
2033+
compile_regions_deepspeed(engine.module, **compile_kwargs)
20342034
else:
20352035
engine.compile(backend=compile_kwargs.pop("backend"), compile_kwargs=compile_kwargs)
20362036
if optimizer is not None:

src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@
270270
check_os_kernel,
271271
clean_state_dict_for_safetensors,
272272
compile_regions,
273+
compile_regions_deepspeed,
273274
convert_bytes,
274275
extract_model_from_parallel,
275276
get_module_children_bottom_up,

src/accelerate/utils/other.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def is_compiled_module(module: torch.nn.Module) -> bool:
6262

6363
def has_compiled_regions(module: torch.nn.Module) -> bool:
6464
"""
65-
Check whether the module has submodules that were compiled with torch.compile()
65+
Check whether the module has submodules that were compiled with `torch.compile()`.
6666
"""
6767
if not hasattr(torch, "_dynamo"):
6868
return False
@@ -75,6 +75,29 @@ def has_compiled_regions(module: torch.nn.Module) -> bool:
7575
return False
7676

7777

78+
def is_repeated_blocks(module: torch.nn.Module) -> bool:
79+
"""
80+
Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This
81+
is useful to determine whether we should apply regional compilation to the module.
82+
"""
83+
84+
return isinstance(module, torch.nn.ModuleList) and all(isinstance(m, module[0].__class__) for m in module)
85+
86+
87+
def has_repeated_blocks(module: torch.nn.Module) -> bool:
88+
"""
89+
Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at
90+
any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the
91+
module.
92+
"""
93+
if module._modules:
94+
for submodule in module.modules():
95+
if is_repeated_blocks(submodule):
96+
return True
97+
98+
return False
99+
100+
78101
def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
79102
"""
80103
Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to
@@ -123,33 +146,54 @@ def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Modul
123146
"""
124147

125148
def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
126-
if isinstance(module, torch.nn.ModuleList):
127-
if all(isinstance(submodule, module[0].__class__) for submodule in module):
128-
new_module = torch.nn.ModuleList()
129-
for submodule in module:
130-
new_module.append(torch.compile(submodule, **compile_kwargs))
131-
else:
132-
new_module = torch.compile(module, **compile_kwargs)
133-
elif module._modules: # Non-leaf node
149+
if is_repeated_blocks(module):
150+
new_module = torch.nn.ModuleList()
151+
for submodule in module:
152+
new_module.append(torch.compile(submodule, **compile_kwargs))
153+
elif has_repeated_blocks(module):
134154
new_module = module.__class__.__new__(module.__class__)
135155
new_module.__dict__.update(module.__dict__)
136156
new_module._modules = {}
137157
for name, submodule in module.named_children():
138158
new_module.add_module(name, _compile_regions(submodule, **compile_kwargs))
139-
else: # Leaf node
159+
else:
140160
new_module = torch.compile(module, **compile_kwargs)
141161

142162
return new_module
143163

144164
new_module = _compile_regions(module, **compile_kwargs)
145165

146-
if not hasattr(new_module, "_orig_mod"):
166+
if "_orig_mod" not in new_module.__dict__:
147167
# Keeps a reference to the original module to decompile/unwrap it later
148168
new_module.__dict__["_orig_mod"] = module
149169

150170
return new_module
151171

152172

173+
def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
174+
"""
175+
Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`.
176+
Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc. that
177+
`torch.compile(...)` interferes with, this version of regional compilation uses the in-place `module.compile()` method
178+
instead.
179+
180+
Args:
181+
module (`torch.nn.Module`):
182+
The model to compile.
183+
**compile_kwargs:
184+
Additional keyword arguments to pass to `module.compile()`.
185+
"""
186+
187+
if is_repeated_blocks(module):
188+
for submodule in module:
189+
submodule.compile(**compile_kwargs)
190+
elif has_repeated_blocks(module):
191+
for child in module.children():
192+
compile_regions_deepspeed(child, **compile_kwargs)
193+
else: # leaf node
194+
module.compile(**compile_kwargs)
195+
196+
153197
def extract_model_from_parallel(
154198
model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
155199
):
@@ -175,9 +219,12 @@ def extract_model_from_parallel(
175219
is_compiled = is_compiled_module(model)
176220
has_compiled = has_compiled_regions(model)
177221

178-
if is_compiled or has_compiled:
222+
if is_compiled:
179223
compiled_model = model
180224
model = model._orig_mod
225+
elif has_compiled:
226+
compiled_model = model
227+
model = model.__dict__["_orig_mod"]
181228

182229
if is_deepspeed_available():
183230
from deepspeed import DeepSpeedEngine
@@ -221,9 +268,13 @@ def _recursive_unwrap(module):
221268
if getattr(model, "_converted_to_transformer_engine", False):
222269
convert_model(model, to_transformer_engine=False)
223270

224-
if keep_torch_compile and (is_compiled or has_compiled):
225-
compiled_model._orig_mod = model
226-
model = compiled_model
271+
if keep_torch_compile:
272+
if is_compiled:
273+
compiled_model._orig_mod = model
274+
model = compiled_model
275+
elif has_compiled:
276+
compiled_model.__dict__["_orig_mod"] = model
277+
model = compiled_model
227278

228279
return model
229280

tests/test_compile.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
from torch.utils.benchmark import Timer
1818

19-
from accelerate.test_utils import require_huggingface_suite, require_non_cpu, require_non_hpu, torch_device
19+
from accelerate.test_utils import require_huggingface_suite, require_non_cpu, require_non_hpu, slow, torch_device
2020
from accelerate.utils import compile_regions, extract_model_from_parallel, release_memory
2121

2222

@@ -58,6 +58,8 @@ def test_regions_are_compiled(self):
5858
# Check that the compiled_model.transformer.h[i] and compiled_model.lm_head are compiled separately
5959
assert isinstance(compiled_model.transformer.h[0], torch._dynamo.eval_frame.OptimizedModule)
6060
assert isinstance(compiled_model.lm_head, torch._dynamo.eval_frame.OptimizedModule)
61+
assert compiled_model.transformer.h[0]._orig_mod is model.transformer.h[0]
62+
assert compiled_model.lm_head._orig_mod is model.lm_head
6163

6264
def test_extract_model_keep_torch_compile(self):
6365
model, _ = self._get_model_and_inputs()
@@ -84,14 +86,14 @@ def test_extract_model_remove_torch_compile(self):
8486
def test_regional_compilation_cold_start(self):
8587
model, input_ids = self._get_model_and_inputs()
8688

87-
regional_compilation_model = compile_regions(model, mode="reduce-overhead", backend=backend)
89+
regional_compilation_model = compile_regions(model, backend=backend)
8890
regional_compilation_cold_start = (
8991
Timer(stmt=COMPILE_STMT, globals={"model": regional_compilation_model, "input_ids": input_ids})
9092
.timeit(COMPILE_ITERS)
9193
.median
9294
)
9395

94-
full_compilation_model = torch.compile(model, mode="reduce-overhead", backend=backend)
96+
full_compilation_model = torch.compile(model, backend=backend)
9597
full_compilation_cold_start = (
9698
Timer(stmt=COMPILE_STMT, globals={"model": full_compilation_model, "input_ids": input_ids})
9799
.timeit(COMPILE_ITERS)
@@ -106,6 +108,7 @@ def test_regional_compilation_cold_start(self):
106108

107109
release_memory(model, full_compilation_model, regional_compilation_model)
108110

111+
@slow
109112
@require_non_cpu
110113
@require_huggingface_suite
111114
def test_regional_compilation_inference_speedup(self):
@@ -115,14 +118,14 @@ def test_regional_compilation_inference_speedup(self):
115118
Timer(stmt=INFRENCE_STMT, globals={"model": model, "input_ids": input_ids}).timeit(INFERENCE_ITERS).median
116119
)
117120

118-
regional_compilation_model = compile_regions(model, mode="reduce-overhead", backend=backend)
121+
regional_compilation_model = compile_regions(model, backend=backend)
119122
regional_compilation_inference_latency = (
120123
Timer(stmt=INFRENCE_STMT, globals={"model": regional_compilation_model, "input_ids": input_ids})
121124
.timeit(INFERENCE_ITERS)
122125
.median
123126
)
124127

125-
full_compilation_model = torch.compile(model, mode="reduce-overhead", backend=backend)
128+
full_compilation_model = torch.compile(model, backend=backend)
126129
full_compilation_inference_latency = (
127130
Timer(stmt=INFRENCE_STMT, globals={"model": full_compilation_model, "input_ids": input_ids})
128131
.timeit(INFERENCE_ITERS)

0 commit comments

Comments
 (0)