Commit a2934b3

Format and warning
1 parent 79cc524 commit a2934b3

3 files changed (+5, -6 lines)


src/transformers/modeling_utils.py

Lines changed: 2 additions & 3 deletions
@@ -5022,16 +5022,15 @@ def tensor_parallel(self, device_mesh):
             device_mesh (`torch.distributed.DeviceMesh`):
                 The device mesh to use for tensor parallelism.
         """
+
         # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
         # No op if `_tp_plan` attribute does not exist under the module.
         # This is a helper function to be used with `model.apply` to recursively
         # parallelize a model.
         def tplize(mod: torch.nn.Module) -> None:
             tp_plan = getattr(mod, "_tp_plan", None)
             if tp_plan:
-                logger.debug(
-                    f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}"
-                )
+                logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
                 # In model configs, we use a neutral type (string) to specify
                 # parallel styles, here we translate them into torch TP types.
                 # Using tree_map because `tp_plan` is a dict.
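For context, a minimal sketch of how the `tensor_parallel` entry point touched above could be driven end to end. The launcher, checkpoint name, dtype, and mesh shape are illustrative assumptions, not part of this commit; the sketch assumes a recent PyTorch with `init_device_mesh` and is meant to be run under `torchrun` with one process per GPU.

# Hypothetical usage sketch; model name, dtype, and mesh layout are assumptions.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoModelForCausalLM

world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
device_mesh = init_device_mesh("cuda", (world_size,))  # 1-D mesh, one rank per GPU

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16)
# Recursively shards every submodule that carries a `_tp_plan` attribute (see tplize above).
model.tensor_parallel(device_mesh)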

src/transformers/models/llama/modeling_llama.py

Lines changed: 2 additions & 0 deletions
@@ -813,6 +813,8 @@ def __init__(self, config: LlamaConfig):

         self.gradient_checkpointing = False
         self._tp_plan = config._base_model_tp_plan
+        if config.pretraining_tp != 1:
+            logger.warn("`pretraining_tp` is deprecated, please use `tensor_parallel` method instead.")
         # Initialize weights and apply final processing
         self.post_init()
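To show when the new warning fires, a toy-sized sketch; the small hyperparameters are arbitrary and chosen only so the snippet is cheap to instantiate.

# Toy config: only `pretraining_tp=2` matters here, the sizes are arbitrary.
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    pretraining_tp=2,  # any value other than 1 now triggers the deprecation warning above
)
model = LlamaModel(config)
# logs: `pretraining_tp` is deprecated, please use `tensor_parallel` method instead.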

src/transformers/pytorch_utils.py

Lines changed: 1 addition & 3 deletions
@@ -340,9 +340,7 @@ def translate_to_torch_parallel_style(style: str):
     types.
     """
     if not isinstance(style, str):
-        raise ValueError(
-            f"Unsupported parallel style type {type(style)}, expected str"
-        )
+        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

     if style == "colwise":
         return ColwiseParallel()
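For reference, a self-contained sketch of what this translation helper does; only the `colwise` branch appears in the hunk above, so the `rowwise` branch and the final error path are assumptions about the surrounding code rather than part of this diff.

# Sketch based on the visible hunk; branches beyond "colwise" are assumed, not shown here.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def translate_to_torch_parallel_style_sketch(style: str):
    """Translate a neutral string from a model's `_tp_plan` into a torch TP style object."""
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    raise ValueError(f"Unsupported parallel style value: {style}")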
