1 change: 1 addition & 0 deletions swift/llm/utils/model.py
@@ -3649,6 +3649,7 @@ def _read_from_stream(container: 'av.container.Container', start_offset: float,
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
    tokenizer.processor = processor
    if model is not None:
+       model.model.embed_tokens.register_forward_hook(_clone_hook)
        model.model.embed_tokens.register_forward_hook(_output_device_map_hook)
    return model, tokenizer
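`_clone_hook` itself is defined elsewhere in model.py and does not appear in this diff. The name suggests it replaces the output of `embed_tokens` with a fresh copy, so that downstream code (the new `_post_encode` in template.py below) can update the embeddings in place without tripping autograd's in-place checks. A minimal sketch of what such a hook could look like; the body is an assumption, only the `register_forward_hook` contract (a non-None return value replaces the module output) is standard:

```python
import torch
from torch import nn

def _clone_hook(module: nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
    # Returning a tensor from a forward hook swaps it in as the module's
    # output; the clone gives callers a tensor they may safely modify
    # in place (assumed behavior -- the real body lives in model.py).
    return output.clone()
```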

28 changes: 27 additions & 1 deletion swift/llm/utils/template.py
@@ -18,6 +18,7 @@
from transformers import PreTrainedTokenizerBase, StoppingCriteria
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.utils import strtobool

from swift.llm.agent.utils import calculate_loss_scale, get_tools_prompt
from swift.torchacc_utils import pad_and_split_batch
@@ -179,6 +180,10 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool:
        return False


+def is_deepspeed_enabled():
+    return strtobool(os.environ.get('ACCELERATE_USE_DEEPSPEED', 'False'))
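Unlike `is_deepspeed_zero3_enabled()`, which fires only for ZeRO stage 3, this helper keys off `ACCELERATE_USE_DEEPSPEED`, which Accelerate exports whenever a DeepSpeed plugin is active, so stage-1/2 runs are covered too. Note that `strtobool` returns `1`/`0` rather than a bool, which is fine in the boolean contexts below. A quick illustration, assuming only that the variable is set the way Accelerate's DeepSpeed path sets it:

```python
import os

os.environ.pop('ACCELERATE_USE_DEEPSPEED', None)
assert not is_deepspeed_enabled()   # unset -> default 'False' -> 0

os.environ['ACCELERATE_USE_DEEPSPEED'] = 'true'
assert is_deepspeed_enabled()       # 'true' -> 1
```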


class Template:
"""A template class for all supported models.

@@ -1504,8 +1509,29 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:

        inputs['input_ids'] = input_ids
        inputs['labels'] = labels
+       inputs['_data'] = {'plain_text': not images and not videos, 'input_ids': torch.tensor(input_ids)[None]}
        return inputs, {}
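The `_data` entry is Swift's channel for deferring work until the model is available: as far as I can tell from the surrounding code, the trainer pops `_data` before the forward pass and hands it to `_post_encode(model, data)`, merging the returned dict (here, precomputed `inputs_embeds`) into the model's kwargs. `plain_text` records whether this sample carries neither images nor videos, and `input_ids` is kept as a `[1, seq_len]` tensor so it can be embedded directly.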

+    def _post_encode(self, model, data: Any) -> Dict[str, Any]:
+        plain_text = data.pop('plain_text', False)
+        if is_deepspeed_enabled() and plain_text:
+            from PIL import Image
+            images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+            processor = self.tokenizer.processor
+            media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
+            input_ids = data['input_ids']
+            device = input_ids.device
+            pixel_values = media_inputs['pixel_values'].to(device)
+            _model = model.model
+            if not hasattr(_model, 'embed_tokens'):
+                _model = _model.model  # LoRA wraps the base model one level deeper
+            inputs_embeds = _model.embed_tokens(input_ids)
+            pixel_values = pixel_values.type(model.visual.get_dtype())
+            image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+            inputs_embeds += image_embeds.mean() * 0.
+            return {'inputs_embeds': inputs_embeds[0]}
+        return {}
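Why the dummy image: DeepSpeed expects every rank to exercise the same parameters on every step — under ZeRO-3 the partitioned vision weights are gathered collectively, and stage-1/2 gradient reduction likewise assumes grads arrive from all ranks — so a rank that draws a text-only batch must still push something through `model.visual`. Adding `image_embeds.mean() * 0.` pulls the tower into the autograd graph with zero numerical effect. A self-contained sketch of the pattern with toy modules (the names are illustrative, not the Qwen2-VL API):

```python
import torch
import torch.nn as nn

class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 16)
        self.visual = nn.Linear(3 * 32 * 32, 16)  # stand-in for the vision tower

    def forward(self, input_ids, pixel_values=None):
        inputs_embeds = self.embed_tokens(input_ids)
        if pixel_values is None:
            # Text-only step: feed a dummy black image so self.visual's
            # parameters run on every rank, keeping collectives in sync.
            pixel_values = torch.zeros(1, 3 * 32 * 32)
        image_embeds = self.visual(pixel_values)
        # Zero-weighted: grads flow to self.visual, activations are unchanged.
        return inputs_embeds + image_embeds.mean() * 0.

model = ToyVLM()
model(torch.tensor([[1, 2, 3]])).sum().backward()
assert model.visual.weight.grad is not None  # the tower received (zero) grads
```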

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super().data_collator(batch, padding_to)
        for media_type in ['image', 'video']:
@@ -2150,7 +2176,7 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
            vit_embeds = model.extract_feature(pixel_values).to(device=device)
            selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
            inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
-        elif is_deepspeed_zero3_enabled():
+        elif is_deepspeed_enabled():
            dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
            vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
            inputs_embeds += vit_embeds.mean() * 0.
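This InternVL-style `_post_encode` already used the zero-weighted dummy forward; the only change here is widening the guard from ZeRO-3 to any DeepSpeed run, presumably because stage-1/2 training can hang the same way when some ranks skip the vision tower.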