Skip to content

Commit 1c4b62b

Browse files
Refactor some core stuff (#36539)
* some config changes * update * current state * update * update * updates and cleanup * something that works * fixup * fixes * nits * nit * nits and fix * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut <[email protected]> * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut <[email protected]> * cleanup * style * safe import * fix * updates * rename stuff an clean * style * small updates * ups * oups * nit * protect imports * update tp * rodfl * arf * turbo nit on init * fix import error * frumble gumbgle * try to fix the import error * should fix the non model test * update keep in float32 * update * fix * nits * fix subvconfigs * test was weird * nit * fix failing test * fix instruct blip * fixes * style * x.com * fix overwrite * ok last bit of failing test --------- Co-authored-by: Lysandre Debut <[email protected]>
1 parent e9756cd commit 1c4b62b

File tree

9 files changed

+704
-116
lines changed

9 files changed

+704
-116
lines changed

src/transformers/configuration_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -824,25 +824,27 @@ def to_diff_dict(self) -> Dict[str, Any]:
824824

825825
serializable_config_dict = {}
826826

827-
# only serialize values that differ from the default config
827+
# Only serialize values that differ from the default config,
828+
# except always keep the 'config' attribute.
828829
for key, value in config_dict.items():
829830
if (
830831
isinstance(getattr(self, key, None), PretrainedConfig)
831832
and key in class_config_dict
832833
and isinstance(class_config_dict[key], dict)
834+
or key in self.sub_configs
833835
):
834836
# For nested configs we need to clean the diff recursively
835-
diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
837+
diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
836838
if "model_type" in value:
837839
# Needs to be set even if it's not in the diff
838840
diff["model_type"] = value["model_type"]
839-
if len(diff) > 0:
840-
serializable_config_dict[key] = diff
841+
serializable_config_dict[key] = diff
841842
elif (
842843
key not in default_config_dict
843844
or key == "transformers_version"
845+
or key == "vocab_file"
844846
or value != default_config_dict[key]
845-
or (key in class_config_dict and value != class_config_dict[key])
847+
or (key in default_config_dict and value != class_config_dict.get(key, value))
846848
):
847849
serializable_config_dict[key] = value
848850

@@ -867,6 +869,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
867869
if "base_model_pp_plan" in serializable_config_dict:
868870
del serializable_config_dict["base_model_pp_plan"]
869871

872+
if "_name_or_path" in serializable_config_dict:
873+
del serializable_config_dict["_name_or_path"]
874+
870875
return serializable_config_dict
871876

872877
def to_dict(self) -> Dict[str, Any]:
@@ -1178,16 +1183,17 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
11781183
"""
11791184
Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
11801185
values from `dict_a` that are different from values in `dict_b`.
1186+
1187+
dict_b : the default config dictionary. We want to remove values that are in this one.
11811188
"""
11821189
diff = {}
11831190
default = config_obj.__class__().to_dict() if config_obj is not None else {}
11841191
for key, value in dict_a.items():
11851192
obj_value = getattr(config_obj, str(key), None)
11861193
if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
11871194
diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
1188-
if len(diff_value) > 0:
1189-
diff[key] = diff_value
1190-
elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
1195+
diff[key] = diff_value
1196+
elif key not in dict_b or (value != default[key]):
11911197
diff[key] = value
11921198
return diff
11931199

src/transformers/integrations/__init__.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from typing import TYPE_CHECKING
1515

16-
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
16+
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_torch_greater_or_equal
1717

1818

1919
_import_structure = {
@@ -128,6 +128,18 @@
128128
"convert_and_export_with_cache",
129129
]
130130

131+
try:
132+
if not is_torch_greater_or_equal("2.3"):
133+
raise OptionalDependencyNotAvailable()
134+
except OptionalDependencyNotAvailable:
135+
pass
136+
else:
137+
_import_structure["tensor_parallel"] = [
138+
"shard_and_distribute_module",
139+
"SUPPORTED_TP_STYLES",
140+
"translate_to_torch_parallel_style",
141+
]
142+
131143
if TYPE_CHECKING:
132144
from .aqlm import replace_with_aqlm_linear
133145
from .awq import (
@@ -231,6 +243,18 @@
231243
else:
232244
from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache
233245

246+
try:
247+
if not is_torch_greater_or_equal("2.3"):
248+
raise OptionalDependencyNotAvailable()
249+
except OptionalDependencyNotAvailable:
250+
pass
251+
else:
252+
from .tensor_parallel import (
253+
SUPPORTED_TP_STYLES,
254+
shard_and_distribute_module,
255+
translate_to_torch_parallel_style,
256+
)
257+
234258
else:
235259
import sys
236260

0 commit comments

Comments
 (0)