     save_fsdp_optimizer,
     wait_for_everyone,
 )
-from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, BETA_TP_AVAILABLE_PYTORCH_VERSION
+from .utils.constants import (
+    BETA_TP_AVAILABLE_PYTORCH_VERSION,
+    BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
+    FSDP_PYTORCH_VERSION,
+    PROFILE_PATTERN_NAME,
+)
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
 
@@ -359,10 +364,15 @@ def __init__(
             if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
                 raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
 
-        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
+        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
+            torch_tp_plugin, TorchTensorParallelPlugin
+        ):
             if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
                 raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
 
+            if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
+                raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
+
         if fsdp_plugin is None:  # init from env variables
             fsdp_plugin = (
                 FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
@@ -373,12 +383,14 @@ def __init__(
             os.environ["ACCELERATE_USE_FSDP"] = "true"  # use FSDP if plugin is provided
 
         if torch_tp_plugin is None:
-            torch_tp_plugin = (TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None)
+            torch_tp_plugin = (
+                TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
+            )
         else:
             if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
                 raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
             os.environ["ACCELERATE_USE_TP"] = "true"
-
+
         if megatron_lm_plugin is None:  # init from env variables
             megatron_lm_plugin = (
                 MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
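
The two `__init__` hunks above gate tensor parallelism behind minimum PyTorch and transformers versions, and let it be enabled either through the `ACCELERATE_USE_TP` environment variable or by passing `torch_tp_plugin` explicitly. A minimal usage sketch follows; it assumes `TorchTensorParallelPlugin` is importable from `accelerate.utils` and that `tp_size` is one of its fields, both of which may differ from the actual API.

import os

from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

# Option 1: rely on the environment flag that __init__ reads above.
os.environ["ACCELERATE_USE_TP"] = "true"

# Option 2: pass the plugin explicitly; __init__ then sets ACCELERATE_USE_TP itself.
tp_plugin = TorchTensorParallelPlugin(tp_size=4)  # tp_size is an assumed field name
accelerator = Accelerator(torch_tp_plugin=tp_plugin)
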
@@ -1489,8 +1501,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             if self.ddp_handler is not None:
                 self.ddp_handler.register_comm_hook(model)
         elif self.distributed_type == DistributedType.TP:
-            if not model.supports_tp_plan:
-                raise NotImplementedError("Provided model does not support tensor parallelism")
+            if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
+                if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
+                    raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
+                raise NotImplementedError(
+                    "Provided model does not support tensor parallelism. \
+                    Tensor parallelism plan can be added as base_model_tp_plan to model config class \
+                    and _tp_plan attribute to model class."
+                )
             model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
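
The expanded error in `prepare_model` names the hooks a model needs before `model.tensor_parallel(...)` can run: a `base_model_tp_plan` mapping on the config class and a `_tp_plan` attribute on the model class. The sketch below only illustrates that shape; the module-name patterns and the wiring between config and model are assumptions, not the transformers implementation.

from transformers import PretrainedConfig, PreTrainedModel


class MyTPConfig(PretrainedConfig):
    model_type = "my_tp_model"
    # Hypothetical plan: maps submodule-name patterns to torch TP partitioning styles.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
    }


class MyTPModel(PreTrainedModel):
    config_class = MyTPConfig
    # Assumed wiring: expose the same plan on the class so `supports_tp_plan` is truthy.
    _tp_plan = MyTPConfig.base_model_tp_plan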