
Commit bc65f3f

[modular] Do not track imports in functions (#36279)
* Add check
* just check for function
* Update examples
1 parent 4b5cf54 commit bc65f3f

10 files changed: +82 -33 lines changed

examples/modular-transformers/configuration_my_new_model.py

Lines changed: 5 additions & 0 deletions

@@ -140,6 +140,11 @@ class MyNewModelConfig(PretrainedConfig):
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,

examples/modular-transformers/configuration_my_new_model2.py

Lines changed: 5 additions & 0 deletions

@@ -43,6 +43,11 @@ class MyNewModel2Config(PretrainedConfig):
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,

examples/modular-transformers/configuration_new_model.py

Lines changed: 14 additions & 0 deletions

@@ -79,6 +79,20 @@ class NewModelConfig(PretrainedConfig):
 
     model_type = "new_model"
     keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,
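For readers unfamiliar with these generated attributes: `base_model_tp_plan` maps parameter-name patterns to tensor-parallel partitioning styles ("colwise"/"rowwise"), and `base_model_pp_plan` declares, per pipeline stage, the expected input and output tensor names. The snippet below is only a rough illustration of how a name pattern in such a plan could be resolved; the `fnmatch`-based lookup and the `partition_style` helper are assumptions made for the example, not the Transformers implementation.

# Minimal sketch: resolving a parameter name against a tp_plan shaped like the
# generated configs above. The fnmatch-based matching is an illustrative
# assumption, not the actual Transformers lookup logic.
from fnmatch import fnmatch

base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.down_proj": "rowwise",
}

def partition_style(param_name):
    """Return the partitioning style for a parameter, or None if it is replicated."""
    for pattern, style in base_model_tp_plan.items():
        if fnmatch(param_name, pattern):
            return style
    return None

print(partition_style("layers.3.self_attn.q_proj"))  # colwise
print(partition_style("layers.3.mlp.down_proj"))     # rowwise
print(partition_style("embed_tokens"))               # None (not sharded)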

examples/modular-transformers/image_processing_new_imgproc_model.py

Lines changed: 2 additions & 3 deletions

@@ -19,7 +19,7 @@
     PILImageResampling,
     infer_channel_dimension_format,
     is_scaled_image,
-    make_list_of_images,
+    make_flat_list_of_images,
     to_numpy_array,
     valid_images,
     validate_preprocess_arguments,
@@ -221,8 +221,7 @@ def preprocess(
 
         size = size if size is not None else self.size
         size = get_size_dict(size, default_to_square=False)
-
-        images = make_list_of_images(images)
+        images = make_flat_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
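The generated image processor now normalizes its inputs with `make_flat_list_of_images` instead of `make_list_of_images`. The helper below is a hypothetical stand-in written only to illustrate the idea of accepting a single image, a list of images, or a nested list of lists and always returning one flat list; it is not the Transformers implementation.

import numpy as np

def flatten_images(images):
    # Hypothetical illustration of "flat list of images" handling: a single
    # array, a list of arrays, or a list of lists of arrays all come out as
    # one flat list of arrays.
    if isinstance(images, np.ndarray):
        return [images]
    flat = []
    for item in images:
        flat.extend(item if isinstance(item, (list, tuple)) else [item])
    return flat

img = np.zeros((3, 8, 8))
assert len(flatten_images(img)) == 1
assert len(flatten_images([img, img])) == 2
assert len(flatten_images([[img, img], [img]])) == 3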

examples/modular-transformers/modeling_dummy.py

Lines changed: 4 additions & 1 deletion

@@ -356,6 +356,7 @@ class DummyPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -698,7 +699,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
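The only change in this mask-building hunk is moving `attention_mask` onto `causal_mask`'s device before adding the two tensors; adding non-scalar tensors that live on different devices raises a runtime error in PyTorch. The same fix appears in modeling_multimodal1.py, modeling_my_new_model2.py, and modeling_super.py below. A minimal, self-contained reproduction of the pattern (assuming an accelerator is available; on a CPU-only machine both tensors already share a device):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

min_dtype = torch.finfo(torch.float32).min
# causal_mask lives on the model's device; attention_mask may still be on CPU.
causal_mask = torch.full((1, 1, 4, 4), min_dtype).triu(1).to(device)
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

mask_length = attention_mask.shape[-1]
# Without .to(causal_mask.device) this addition fails when the devices differ.
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
    causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask, min_dtype
)
print(causal_mask[0, 0])  # the padded column is now fully masked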

examples/modular-transformers/modeling_multimodal1.py

Lines changed: 4 additions & 1 deletion

@@ -356,6 +356,7 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -698,7 +699,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype

examples/modular-transformers/modeling_my_new_model2.py

Lines changed: 17 additions & 10 deletions

@@ -356,6 +356,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -491,6 +492,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwarg for now
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -703,7 +705,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
@@ -787,17 +791,20 @@ def forward(
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
 
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
 
         loss = None
         if labels is not None:
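The new pooling logic picks the rightmost non-pad token per row, which handles both left- and right-padded batches, whereas the previous approach located the first pad token and effectively assumed right padding. A small standalone check of the arithmetic used above:

import torch

pad_token_id = 0
# one right-padded row and one left-padded row
input_ids = torch.tensor([[5, 6, 7, 0], [0, 0, 8, 9]])

non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1])
# argmax returns the largest position whose token is not padding
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 3])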

examples/modular-transformers/modeling_new_task_model.py

Lines changed: 22 additions & 14 deletions

@@ -19,6 +19,7 @@
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
+from ...utils.deprecation import deprecate_kwarg
 from ..auto import AutoModel, AutoModelForCausalLM
 from .configuration_new_task_model import NewTaskModelConfig
 
@@ -254,8 +255,7 @@ def _update_causal_mask(
         token_type_ids,
         past_key_values,
         cache_position,
-        input_ids=None,
-        inputs_embeds=None,
+        input_tensor,
         is_training: bool = False,
     ):
         if self.config.text_config._attn_implementation == "flash_attention_2":
@@ -265,8 +265,7 @@
 
         using_static_cache = isinstance(past_key_values, StaticCache)
         min_dtype = torch.finfo(self.dtype).min
-        inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
-        sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
         elif isinstance(past_key_values, HybridCache):
@@ -297,16 +296,20 @@
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
+
+            # First unmask prefix tokens during training
+            if is_training:
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
+                )
+
+            # Then apply padding mask (will mask pad tokens)
             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
             )
-            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
-            if is_training:
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
-                )
+
         return causal_mask
 
     def get_image_features(self, pixel_values: torch.FloatTensor):
@@ -325,6 +328,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor):
         image_features = image_features / (self.config.text_config.hidden_size**0.5)
         return image_features
 
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -351,10 +355,12 @@ def forward(
                 config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
         Returns:
 
@@ -418,7 +424,7 @@ def prepare_inputs_for_generation(
         attention_mask=None,
         token_type_ids=None,
        use_cache=True,
-        num_logits_to_keep=None,
+        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
@@ -431,7 +437,7 @@ def prepare_inputs_for_generation(
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
-            num_logits_to_keep=num_logits_to_keep,
+            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )
@@ -445,10 +451,12 @@ def prepare_inputs_for_generation(
            model_inputs["pixel_values"] = pixel_values
        is_training = token_type_ids is not None and labels is not None
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
            causal_mask = self._update_causal_mask(
-                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
            )
            model_inputs["attention_mask"] = causal_mask
+
        return model_inputs
 
    def resize_token_embeddings(
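Alongside the causal-mask refactor, the generated model deprecates `num_logits_to_keep` in favour of `logits_to_keep`, which accepts either an int (keep the last N positions, 0 meaning all) or a 1D tensor of indices (useful for packed sequences). A rough sketch of the slicing semantics described in the updated docstring; `select_positions` and the tensor shapes are illustrative, not the model's actual code:

import torch

batch, seq_len, hidden = 2, 6, 4
hidden_states = torch.randn(batch, seq_len, hidden)

def select_positions(hidden_states, logits_to_keep):
    # int: keep the last `logits_to_keep` positions (0 means keep everything)
    if isinstance(logits_to_keep, int):
        return hidden_states if logits_to_keep == 0 else hidden_states[:, -logits_to_keep:, :]
    # 1D tensor: keep exactly those positions along the sequence dimension
    return hidden_states[:, logits_to_keep, :]

print(select_positions(hidden_states, 1).shape)                        # torch.Size([2, 1, 4])
print(select_positions(hidden_states, 0).shape)                        # torch.Size([2, 6, 4])
print(select_positions(hidden_states, torch.tensor([0, 3, 5])).shape)  # torch.Size([2, 3, 4])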

examples/modular-transformers/modeling_super.py

Lines changed: 4 additions & 1 deletion

@@ -356,6 +356,7 @@ class SuperPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -620,7 +621,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype

utils/modular_model_converter.py

Lines changed: 5 additions & 3 deletions

@@ -649,9 +649,11 @@ def leave_FunctionDef(self, node):
        self.current_function = None
 
    def visit_If(self, node):
-        for stmt in node.body.body:
-            if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
-                self.imports.append(node)
+        # If we are inside a function, do not add the import to the list of imports
+        if self.current_function is None:
+            for stmt in node.body.body:
+                if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
+                    self.imports.append(node)
 
    def visit_ClassDef(self, node: ClassDef) -> None:
        """Record class nodes to create their dependencies at the end."""
