Skip to content

Commit cc365b6

Browse files
authored
Add sharding for speechlm and vlm (#11876)
* Add sharding for speechlm and vlm Signed-off-by: Boxiang Wang <boxiangw@nvidia.com> * Add ci test for VLM Signed-off-by: Boxiang Wang <boxiangw@nvidia.com> * Apply isort and black reformatting Signed-off-by: BoxiangW <BoxiangW@users.noreply.github.com> --------- Signed-off-by: Boxiang Wang <boxiangw@nvidia.com> Signed-off-by: BoxiangW <BoxiangW@users.noreply.github.com> Co-authored-by: BoxiangW <BoxiangW@users.noreply.github.com>
1 parent 7e24313 commit cc365b6

File tree

5 files changed

+198
-15
lines changed

5 files changed

+198
-15
lines changed

.github/workflows/cicd-main.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3637,6 +3637,7 @@ jobs:
36373637
TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft_hf.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --strategy fsdp --devices 2
36383638
AFTER_SCRIPT: |
36393639
rm -rf nemo_experiments
3640+
36403641
L2_VLM_HF_Transformer_PEFT_4bit:
36413642
needs: [ cicd-test-container-setup ]
36423643
uses: ./.github/workflows/_test_template.yml
@@ -3648,6 +3649,17 @@ jobs:
36483649
AFTER_SCRIPT: |
36493650
rm -rf nemo_experiments
36503651
3652+
L2_VLM_HF_Transformer_SFT_FSDP2:
3653+
needs: [ cicd-test-container-setup ]
3654+
uses: ./.github/workflows/_test_template.yml
3655+
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_VLM_HF_Transformer_SFT_FSDP2') || needs.cicd-test-container-setup.outputs.all == 'true'
3656+
with:
3657+
RUNNER: self-hosted-azure-gpus-1
3658+
SCRIPT: |
3659+
TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/sft_fsdp2.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3
3660+
AFTER_SCRIPT: |
3661+
rm -rf nemo_experiments
3662+
36513663
L2_HF_Transformer_PEFT:
36523664
needs: [ cicd-test-container-setup ]
36533665
uses: ./.github/workflows/_test_template.yml
@@ -5092,6 +5104,7 @@ jobs:
50925104
- L2_VLM_HF_Transformer_PEFT
50935105
- L2_VLM_HF_Transformer_PEFT_FSDP
50945106
- L2_VLM_HF_Transformer_PEFT_4bit
5107+
- L2_VLM_HF_Transformer_SFT_FSDP2
50955108
- L2_HF_Transformer_SFT_2gpu_nemorun
50965109
- L2_HF_Transformer_SFT_TE_Acceleration
50975110
- L2_HF_Transformer_PT

nemo/collections/speechlm/models/hf_auto_model_for_speech_seq2seq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
2121
from nemo.collections.llm import fn
2222
from nemo.lightning import io
23+
from nemo.lightning.pytorch.strategies.utils import fsdp2_strategy_parallelize
2324
from nemo.utils import logging
2425

2526

@@ -94,6 +95,10 @@ def configure_model(self, train=True):
9495
config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
9596
self.model = AutoModelForSpeechSeq2Seq.from_config(config, trust_remote_code=self.trust_remote_code)
9697

98+
# Apply FSDP2 and TP to the model
99+
if self.device_mesh is not None:
100+
fsdp2_strategy_parallelize(self.model, device_mesh=self.device_mesh, model_type="speech_seq2seq")
101+
97102
if train:
98103
self.model.train()
99104

@@ -104,7 +109,7 @@ def forward(self, input_features, decoder_input_ids, attention_mask=None):
104109
decoder_input_ids=decoder_input_ids,
105110
)
106111

107-
def training_step(self, batch):
112+
def training_step(self, batch, batch_idx=None):
108113
outputs = self.forward(input_features=batch["input_features"], decoder_input_ids=batch["decoder_input_ids"])
109114
loss_mask = batch.get('loss_mask', None)
110115
if loss_mask is not None:

nemo/collections/vlm/hf/model/hf_auto_model_for_image_text_to_text.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from nemo.collections.llm import fn
2121
from nemo.lightning import io
22+
from nemo.lightning.pytorch.strategies.utils import fsdp2_strategy_parallelize
2223
from nemo.utils import logging
2324

2425

@@ -95,13 +96,18 @@ def configure_model(self):
9596
self.model = AutoModelForImageTextToText.from_config(
9697
config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
9798
)
99+
100+
# Apply FSDP2 and TP to the model
101+
if self.device_mesh is not None:
102+
fsdp2_strategy_parallelize(self.model, device_mesh=self.device_mesh)
103+
98104
self.model.train()
99105

100106
def forward(self, batch):
101107
"""Runs forward with the model"""
102108
return self.model(**batch)
103109

104-
def training_step(self, batch):
110+
def training_step(self, batch, batch_idx=None):
105111
"""Run one training step"""
106112
labels = batch.pop('labels').to(self.model.device)
107113
loss_mask = batch.pop('loss_mask', None)

nemo/lightning/pytorch/strategies/utils.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def _convert(state_dict, k, sh_key, v, prepend_offsets, prefix="", allow_shape_m
345345

346346
# Taken and modified from torchtitan
347347
# https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
348-
def fsdp2_strategy_parallelize(model, device_mesh: DeviceMesh = None):
348+
def fsdp2_strategy_parallelize(model, device_mesh: DeviceMesh = None, model_type: str = None):
349349
"""Apply parallelisms and activation checkpointing to the model.
350350
NOTE: The passed-in model preferably should be on meta device. Otherwise,
351351
the model must fit on GPU or CPU memory.
@@ -364,18 +364,43 @@ def fsdp2_strategy_parallelize(model, device_mesh: DeviceMesh = None):
364364
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
365365

366366
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
367-
for layer_id, transformer_block in enumerate(model.model.layers):
368-
# Apply activation checkpointing
369-
# transformer_block = checkpoint_wrapper(transformer_block)
370-
# As an optimization, do not reshard after forward for the last
371-
# transformer block since FSDP would prefetch it immediately
372-
reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
373-
fully_shard(
374-
transformer_block,
375-
**fsdp_config,
376-
reshard_after_forward=reshard_after_forward,
377-
)
378-
model.model.layers[layer_id] = transformer_block
367+
if model_type == "speech_seq2seq":
368+
for layer_id, transformer_block in enumerate(model.model.encoder.layers):
369+
# Apply activation checkpointing
370+
# transformer_block = checkpoint_wrapper(transformer_block)
371+
# As an optimization, do not reshard after forward for the last
372+
# transformer block since FSDP would prefetch it immediately
373+
reshard_after_forward = int(layer_id) < len(model.model.encoder.layers) - 1
374+
fully_shard(
375+
transformer_block,
376+
**fsdp_config,
377+
reshard_after_forward=reshard_after_forward,
378+
)
379+
model.model.encoder.layers[layer_id] = transformer_block
380+
381+
for layer_id, transformer_block in enumerate(model.model.decoder.layers):
382+
# transformer_block = checkpoint_wrapper(transformer_block)
383+
reshard_after_forward = int(layer_id) < len(model.model.decoder.layers) - 1
384+
fully_shard(
385+
transformer_block,
386+
**fsdp_config,
387+
reshard_after_forward=reshard_after_forward,
388+
)
389+
model.model.decoder.layers[layer_id] = transformer_block
390+
else:
391+
for layer_id, transformer_block in enumerate(model.model.layers):
392+
# Apply activation checkpointing
393+
# transformer_block = checkpoint_wrapper(transformer_block)
394+
# As an optimization, do not reshard after forward for the last
395+
# transformer block since FSDP would prefetch it immediately
396+
reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
397+
fully_shard(
398+
transformer_block,
399+
**fsdp_config,
400+
reshard_after_forward=reshard_after_forward,
401+
)
402+
model.model.layers[layer_id] = transformer_block
403+
379404
model = fully_shard(model, **fsdp_config)
380405

381406
return model
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from importlib.metadata import version
16+
17+
import fiddle as fdl
18+
import torch
19+
from lightning.pytorch.loggers import WandbLogger
20+
from packaging.version import Version as PkgVersion
21+
22+
from nemo import lightning as nl
23+
from nemo.collections import llm, vlm
24+
25+
DATA_PATH = "/home/TestData/vlm/rdr-items"
26+
27+
28+
def get_torch_version_str():
    """Return the installed torch version as a string.

    Prefers ``torch.__version__`` (present on all supported torch builds);
    falls back to :func:`importlib.metadata.version` when the attribute is
    unavailable.

    Returns:
        str: the torch version, e.g. ``"2.4.0"``.
    """
    # torch is already imported at module level; no need to re-import here.
    if hasattr(torch, '__version__'):
        return str(torch.__version__)
    return version("torch")
35+
36+
37+
def mk_hf_vlm_dataset(processor, mbs, gbs):
    """Build an ``HFDatasetDataModule`` over the test image-caption dataset.

    Args:
        processor: HF multimodal processor used for chat templating and
            tokenization.
        mbs: micro batch size.
        gbs: global batch size.

    Returns:
        A ``vlm.HFDatasetDataModule`` over the first 10 training samples,
        with a collate function that produces model-ready batches.
    """
    skipped_tokens = vlm.HFAutoModelForImageTextToText.extract_skipped_token_ids(processor)

    def to_chat(sample):
        # Wrap a raw (image, text) sample into the chat-template structure
        # expected by processor.apply_chat_template.
        instruction = "Describe accurately the given image."
        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": instruction}, {"type": "image", "image": sample["image"]}],
            },
            {"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
        ]
        return {"conversation": conversation, "images": [sample['image']]}

    def collate_fn(examples, processor=processor):
        formatted = [to_chat(ex) for ex in examples]
        text = [
            processor.apply_chat_template(
                item["conversation"],
                tokenize=False,
                add_generation_prompt=False,
            )
            for item in formatted
        ]
        images = []
        for item in formatted:
            images += item['images']

        # Tokenize the text and process the images
        batch = processor(
            text=text,
            images=images,
            padding=True,
            return_tensors="pt",
        )

        batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)

        # Copy input ids into labels, masking the skipped/special token ids
        # with -100 (the conventional HF ignore index for the loss).
        labels = batch["input_ids"].clone()
        labels[torch.isin(labels, skipped_tokens)] = -100
        batch["labels"] = labels
        return batch

    return vlm.HFDatasetDataModule(
        DATA_PATH,
        split="train[:10]",
        micro_batch_size=mbs,
        global_batch_size=gbs,
        collate_fn=collate_fn,
    )
86+
87+
88+
if __name__ == '__main__':
    # FSDP2 requires a recent torch; on older builds the script exits without
    # doing anything (no else branch), which keeps CI green on old containers.
    if PkgVersion(get_torch_version_str()) >= PkgVersion("2.4"):
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument('--model', default='Qwen/Qwen2-VL-2B-Instruct')
        # type=int so CLI-supplied values are parsed as ints; argparse would
        # otherwise pass them through as strings (defaults alone stay ints).
        parser.add_argument('--devices', type=int, default=2)
        parser.add_argument('--mbs', type=int, default=1)
        parser.add_argument('--gbs', type=int, default=1)
        parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
        parser.add_argument('--max-steps', type=int, default=100)
        parser.add_argument('--wandb-project', type=str, default=None)
        # store_false: passing the flag DISABLES checkpointing (value False).
        parser.add_argument('--disable-ckpt', action='store_false')
        parser.add_argument('--use-4bit', help="Load model in 4bit", action="store_true")
        args = parser.parse_args()

        wandb = None
        if args.wandb_project is not None:
            model = '_'.join(args.model.split('/')[-2:])
            wandb = WandbLogger(
                project=args.wandb_project,
                name=f'{model}_dev{args.devices}_strat_fsdp2',
            )
        grad_clip = None
        use_dist_samp = False
        processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model)

        llm.api.finetune(
            model=vlm.HFAutoModelForImageTextToText(args.model, load_in_4bit=args.use_4bit),
            data=mk_hf_vlm_dataset(processor, args.mbs, args.gbs),
            trainer=nl.Trainer(
                devices=args.devices,
                max_steps=args.max_steps,
                accelerator=args.accelerator,
                # Follow --devices instead of a hard-coded data-parallel size
                # of 2 (identical for the default/CI invocation).
                strategy=nl.FSDP2Strategy(data_parallel_size=args.devices, tensor_parallel_size=1),
                log_every_n_steps=1,
                limit_val_batches=0.0,
                num_sanity_val_steps=0,
                accumulate_grad_batches=10,
                gradient_clip_val=grad_clip,
                use_distributed_sampler=use_dist_samp,
                logger=wandb,
                enable_checkpointing=args.disable_ckpt,
            ),
            optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
            log=None,
        )

0 commit comments

Comments
 (0)