Skip to content

Commit 2c6b31c

Browse files
authored
FP16 optimizer automatically detect DeepSpeed compatibility (#18084)
### FP16 optimizer automatically detect DeepSpeed compatibility Optimum/Transformers use the accelerate library to prepare models, so our FP16 optimizer wrapper has not worked for a long time: the optimizer's type name is `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper`, which under the hood still calls into the DeepSpeed stage 1 and 2 optimizer. This PR includes the following changes: 1. Add `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper` to the modifier registry, plus a check that its contained `optimizer` property MUST be a DeepSpeed stage 1 and 2 optimizer (let's cover the Stage 3 optimizer later). 2. For DeepSpeed versions > 0.9.1, we store the source code in a version list. As long as the related functions in DeepSpeed remain unchanged in new releases, we no longer need to manually upgrade the version check. If some day the source code does not match, a warning is raised asking users to add a new version of the source code to the list. With the above changes, our FP16 Optimizer works again in Optimum. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/d35b4aa9-b371-46f1-98ae-73114f91179b)
1 parent ae85619 commit 2c6b31c

5 files changed

Lines changed: 223 additions & 31 deletions

File tree

.lintrunner.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ exclude_patterns = [
4545
'cmake/external/**',
4646
# ignore generated flatbuffers code
4747
'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**',
48+
'orttraining/orttraining/python/training/optim/_ds_code_store.py',
4849
]
4950
command = [
5051
'python',
@@ -76,6 +77,7 @@ exclude_patterns = [
7677
'cmake/**',
7778
'orttraining/*',
7879
'onnxruntime/core/flatbuffers/**',
80+
'orttraining/orttraining/python/training/optim/_ds_code_store.py',
7981
]
8082
command = [
8183
'python',
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
#
5+
# Copyright 2020 The Microsoft DeepSpeed Team
6+
#
7+
# !!!IMPORTANT: This file is a copy of the original one in DeepSpeed repo at given version,
8+
# It is used to compare with the source code of current installed DeepSpeed during runtime.
9+
# Please don't modify it or do any code formatting for it.
10+
# 'orttraining/orttraining/python/training/optim/_ds_code_store.py' is removed from lintrunner config by intention.
11+
# --------------------------------------------------------------------------
12+
13+
# Wrap code in this to make sure the indentation is correct compared with raw DeepSpeed.
14+
15+
class Stage1And2_DeepSpeedZeroOptimizer_0_9_2:
16+
17+
def has_overflow_serial(self, params, is_grad_list=False):
18+
for p in params:
19+
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
20+
return True
21+
22+
return False
23+
24+
25+
def get_grad_norm_direct(self, gradients, params, norm_type=2):
26+
"""Clips gradient norm of an iterable of parameters.
27+
28+
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
29+
added functionality to handle model parallel parameters. Note that
30+
the gradients are modified in place.
31+
32+
Arguments:
33+
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
34+
single Tensor that will have gradients normalized
35+
max_norm (float or int): max norm of the gradients
36+
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
37+
infinity norm.
38+
39+
Returns:
40+
Total norm of the parameters (viewed as a single vector).
41+
"""
42+
norm_type = float(norm_type)
43+
if norm_type == inf:
44+
total_norm = max(g.data.abs().max() for g in gradients)
45+
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
46+
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)
47+
48+
# Take max across all GPUs.
49+
self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
50+
total_norm = total_norm_cuda[0].item()
51+
else:
52+
total_norm = 0.0
53+
# if dist.get_rank() == 0:
54+
# logger.info(f"Total Norm beginning {total_norm}")
55+
for g, p in zip(gradients, params):
56+
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
57+
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
58+
continue
59+
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
60+
param_norm = g.data.double().norm(2)
61+
total_norm += param_norm.item()**2
62+
# Sum across all model parallel GPUs.
63+
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
64+
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)
65+
66+
self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
67+
68+
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
69+
70+
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
71+
total_norm = -1
72+
73+
return total_norm
74+
75+
76+
def has_overflow_partitioned_grads_serial(self):
77+
for i in range(len(self.bit16_groups)):
78+
for j, grad in enumerate(self.averaged_gradients[i]):
79+
if grad is not None and self._has_inf_or_nan(grad.data, j):
80+
return True
81+
return False

orttraining/orttraining/python/training/optim/_ds_modifier.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,112 @@
1010
# - has_overflow_partitioned_grads_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1799
1111
# --------------------------------------------------------------------------
1212

13+
from __future__ import annotations
14+
15+
import inspect
1316
import types
1417
import warnings
1518

1619
import torch
1720
from numpy import inf
1821
from packaging.version import Version
1922

23+
from ._ds_code_store import Stage1And2_DeepSpeedZeroOptimizer_0_9_2
2024
from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads
2125
from ._multi_tensor_apply import MultiTensorApply
2226

2327
multi_tensor_applier = MultiTensorApply(2048 * 32)
2428

2529

30+
def _get_normalized_str(function) -> str:
    """Return the source text of ``function`` as reported by ``inspect.getsource``.

    The text is returned exactly as retrieved; no whitespace or comment
    normalization is applied here, so callers compare sources verbatim.
    """
    source_text = inspect.getsource(function)
    return source_text
32+
33+
34+
def _dynamic_checks(cur_ds_version: Version, optimizer) -> bool:
    """Check that the installed DeepSpeed functions we override match a stored reference copy.

    The newest stored version that is smaller than or equal to ``cur_ds_version`` is selected,
    and the source text of each function to override is compared between ``optimizer`` and the
    stored class for that version.

    Args:
        cur_ds_version: Version of the currently installed DeepSpeed.
        optimizer: The optimizer object whose functions are going to be overridden.

    Returns:
        True only if every function exists on ``optimizer`` and its source is identical to the
        stored copy. Otherwise False, after emitting a ``UserWarning`` for each problem found.
    """
    _functions_to_override = ["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"]

    _version_to_source_code_map = {"0.9.2": Stage1And2_DeepSpeedZeroOptimizer_0_9_2}

    # Find the biggest stored version that is smaller than or equal to cur_ds_version.
    # Keying by Version (rather than by string) avoids a str() round trip for the lookup below.
    stored_classes = {Version(v): cls for v, cls in _version_to_source_code_map.items()}
    version_to_compare = max((v for v in stored_classes if v <= cur_ds_version), default=None)

    if version_to_compare is None:
        warnings.warn(
            "Unable to find a DeepSpeed version that is smaller than or equal to the current version "
            f"{cur_ds_version}. Skip modifying optimizer.",
            UserWarning,
        )
        return False

    v_optimizer_cls = stored_classes[version_to_compare]
    all_match = True
    for func_name in _functions_to_override:
        # Use a default so that a missing function produces the warning below instead of an
        # AttributeError.
        cur_func = getattr(optimizer, func_name, None)
        if cur_func is None:
            warnings.warn(
                f"DeepSpeed function {func_name} is not found in optimizer. Skip modifying optimizer.", UserWarning
            )
            all_match = False
            # Nothing to compare for this function; keep checking the remaining ones so that
            # all problems are reported in a single run.
            continue

        cur_code_str = _get_normalized_str(cur_func)
        v_code_str = _get_normalized_str(getattr(v_optimizer_cls, func_name))
        if cur_code_str != v_code_str:
            warnings.warn(
                f"DeepSpeed function {func_name} has changed after version {version_to_compare}. "
                f"Please append new version {cur_ds_version} in _version_to_source_code_map and _ds_code_store.py.\n"
                f"---[{func_name}] Old Source Code Start----\n"
                f"{v_code_str}\n"
                f"---{func_name} Old Source Code End----\n"
                f"---[{func_name}] New Source Code Start----\n"
                f"{cur_code_str}\n"
                f"---{func_name} New Source Code End----",
                UserWarning,
            )
            all_match = False

    return all_match
84+
85+
2686
class DeepSpeedZeROModifier(FP16OptimizerModifier):
2787
def __init__(self, optimizer, **kwargs) -> None:
2888
super().__init__(optimizer)
2989

3090
def can_be_modified(self):
3191
import deepspeed
3292

93+
# Note 1:
3394
# This modifier relies on the implementation of has_overflow_serial, get_grad_norm_direct,
3495
# and has_overflow_partitioned_grads_serial
3596
# in https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py.
36-
# Everytime if we want to update this version supporting list to a newer version,
37-
# we need to check if the implementation of these functions are changed.
38-
# An easy way to check is to check the history of this file, if there is no change during the update,
97+
# The minimum version supported is 0.4.0, all versions in between [0.4.0, 0.9.1]
98+
# are manually checked to make sure the implementation of these functions is "logically" unchanged.
99+
# We did the check by reviewing the history of this file: if there is no change during the update,
39100
# it's safe to update the version supporting list. Otherwise, or the file is moved or renamed,
40101
# we need to check the implementation of these functions in detail.
102+
#
103+
# Note 2:
104+
# Since version 0.9.2, we added dynamic source code check, by comparing installed version of code with
105+
# the source code in our code store. If the source code is changed, we will raise a warning to ask user
106+
# to add the new version to the code store. Otherwise, we will override the functions.
107+
41108
ds_version = Version(deepspeed.__version__)
42-
if ds_version > Version("0.9.1") or ds_version < Version("0.4.0"):
109+
if ds_version < Version("0.4.0"):
110+
warnings.warn(
111+
f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}, "
112+
"minimum supported version: 0.4.0, current version",
113+
UserWarning,
114+
)
115+
return False
116+
if ds_version > Version("0.9.1") and not _dynamic_checks(ds_version, self._optimizer):
43117
warnings.warn(
44-
"Skip modifying optimizer because of unsupported DeepSpeed version {}, "
45-
"supported version: 0.4.0 - 0.9.1.".format(deepspeed.__version__),
118+
f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}.",
46119
UserWarning,
47120
)
48121
return False

orttraining/orttraining/python/training/optim/_modifier_registry.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,59 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6+
from __future__ import annotations
7+
8+
import warnings
9+
from typing import ClassVar
10+
611
from ._apex_amp_modifier import ApexAMPModifier
712
from ._ds_modifier import DeepSpeedZeROModifier
813
from ._megatron_modifier import LegacyMegatronLMModifier
14+
from ._modifier import FP16OptimizerModifier
15+
16+
17+
class _AccelerateDeepSpeedZeROModifier(DeepSpeedZeROModifier):
    """
    DeepSpeed ZeRO modifier for a DeepSpeed optimizer wrapped by accelerate.

    The wrapper exposes the underlying optimizer through its ``optimizer`` attribute;
    that inner object is what gets handed to DeepSpeedZeROModifier.
    https://github.com/huggingface/accelerate/blob/7843286f2e1c50735d259fbc0084a7f1c85e00e3/src/accelerate/utils/deepspeed.py#L182C19-L182C19
    """

    def __init__(self, accelerator_optimizer, **kwargs) -> None:
        # kwargs are accepted so the constructor matches the other modifiers; they are not forwarded.
        wrapped_optimizer = accelerator_optimizer.optimizer
        super().__init__(wrapped_optimizer)
25+
26+
27+
def get_full_qualified_type_name(o):
    """Return the fully qualified class name of ``o``.

    The result is ``<module>.<qualname>`` of ``o.__class__``; for classes living in
    ``builtins`` only the qualified name is returned (e.g. ``"int"``).
    """
    cls = o.__class__
    owner, qualname = cls.__module__, cls.__qualname__
    if owner != "builtins":
        return ".".join((owner, qualname))
    return qualname
33+
34+
35+
class OptimizerModifierTypeRegistry:
    """Registry that maps fully qualified optimizer type names to FP16 optimizer modifiers."""

    # Values are modifier classes (not instances); create_modifier instantiates them on demand.
    _MAP: ClassVar[dict[str, type[FP16OptimizerModifier]]] = {
        "megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier,
        "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
        "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
        "apex.amp.optimizer.unique_name_as_id": ApexAMPModifier,
    }

    @staticmethod
    def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> FP16OptimizerModifier | None:
        """Create modifier for optimizer.

        Args:
            optimizer_full_qualified_name: Fully qualified type name used as the registry key.
            optimizer: The optimizer object to be wrapped by the modifier.
            **kwargs: Extra keyword arguments passed to the modifier constructor.

        Returns:
            A modifier instance, or None (after a ``UserWarning``) when no modifier applies.
        """
        registry = OptimizerModifierTypeRegistry._MAP

        modifier_cls = registry.get(optimizer_full_qualified_name)
        if modifier_cls is not None:
            return modifier_cls(optimizer, **kwargs)

        if optimizer_full_qualified_name == "accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper":
            inner_optimizer = getattr(optimizer, "optimizer", None)
            # The accelerate wrapper is only supported when it wraps an optimizer handled by
            # DeepSpeedZeROModifier; any other registered type (Megatron, Apex) must not be routed
            # to the DeepSpeed-specific wrapper modifier.
            if (
                inner_optimizer is not None
                and registry.get(get_full_qualified_type_name(inner_optimizer)) is DeepSpeedZeROModifier
            ):
                return _AccelerateDeepSpeedZeROModifier(optimizer, **kwargs)

        warnings.warn(
            "Skip modifying optimizer because of optimizer name not found in the registry: "
            f"{optimizer_full_qualified_name}",
            UserWarning,
        )
        return None

orttraining/orttraining/python/training/optim/fp16_optimizer.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6-
import warnings
76

8-
from ._modifier_registry import OptimizerModifierTypeRegistry
7+
from ._modifier_registry import OptimizerModifierTypeRegistry, get_full_qualified_type_name
98

109

1110
def FP16_Optimizer(optimizer, **kwargs): # noqa: N802
@@ -80,22 +79,13 @@ def FP16_Optimizer(optimizer, **kwargs): # noqa: N802
8079
8180
"""
8281

83-
def get_full_qualified_type_name(o):
84-
if hasattr(optimizer, "_amp_stash"):
85-
return "apex.amp.optimizer.unique_name_as_id"
86-
87-
klass = o.__class__
88-
module = klass.__module__
89-
if module == "builtins":
90-
return klass.__qualname__
91-
return module + "." + klass.__qualname__
92-
93-
optimizer_full_qualified_name = get_full_qualified_type_name(optimizer)
94-
if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry:
95-
warnings.warn("Skip modifying optimizer because of optimizer name not found in registry.", UserWarning)
96-
return optimizer
97-
98-
modifier = OptimizerModifierTypeRegistry[optimizer_full_qualified_name](optimizer, **kwargs)
99-
modifier.apply()
82+
optimizer_full_qualified_name = (
83+
"apex.amp.optimizer.unique_name_as_id"
84+
if hasattr(optimizer, "_amp_stash")
85+
else get_full_qualified_type_name(optimizer)
86+
)
87+
modifier = OptimizerModifierTypeRegistry.create_modifier(optimizer_full_qualified_name, optimizer, **kwargs)
88+
if modifier is not None:
89+
modifier.apply()
10090

10191
return optimizer

0 commit comments

Comments
 (0)