
Commit 4a79bf9

Fix bugs in fine-tuning and batch inference for GLM-4.1V (#39090)
* update
* 1
1 parent 2100ee6 commit 4a79bf9

3 files changed · +13 −14 lines changed


src/transformers/models/glm4v/image_processing_glm4v_fast.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -121,6 +121,7 @@ def _preprocess(
         do_convert_rgb: bool,
         input_data_format: Optional[Union[str, ChannelDimension]],
         device: Optional[Union[str, torch.device]],
+        disable_grouping: Optional[bool],
     ):
         """
         Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
@@ -173,7 +174,7 @@ def _preprocess(
         resized_height, resized_width = height, width

         # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
         resized_images_grouped = {}
         for shape, stacked_images in grouped_images.items():
             if do_resize:
@@ -191,7 +192,7 @@ def _preprocess(
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
         # Group images by size for further processing
         # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
         processed_images_grouped = {}
         for shape, stacked_images in grouped_images.items():
             # Fused rescale and normalize
```
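For context, `group_images_by_shape` / `reorder_images` are the fast-image-processor helpers that stack same-sized images so resize and rescale run once per shape, and the new `disable_grouping` flag opts out of that stacking. A minimal sketch of the idea (a toy reimplementation for illustration, not the transformers helpers):

```python
# Toy sketch of group-by-shape batching: images with the same (C, H, W) are stacked so a
# transform runs once per shape; disable_grouping=True falls back to one group per image.
from collections import defaultdict
import torch


def group_images_by_shape_sketch(images, disable_grouping=False):
    """Return ({shape_key: stacked_tensor}, {original_index: (shape_key, position)})."""
    grouped, index = defaultdict(list), {}
    for i, img in enumerate(images):
        key = (i,) + tuple(img.shape) if disable_grouping else tuple(img.shape)
        index[i] = (key, len(grouped[key]))
        grouped[key].append(img)
    return {k: torch.stack(v) for k, v in grouped.items()}, index


def reorder_images_sketch(processed_groups, index):
    """Undo the grouping so outputs line up with the original image order."""
    return [processed_groups[key][pos] for key, pos in (index[i] for i in sorted(index))]


imgs = [torch.rand(3, 224, 224), torch.rand(3, 112, 112), torch.rand(3, 224, 224)]
grouped, idx = group_images_by_shape_sketch(imgs, disable_grouping=False)
restored = reorder_images_sketch({k: v * 2.0 for k, v in grouped.items()}, idx)  # fake "processing"
assert all(torch.equal(r, im * 2.0) for r, im in zip(restored, imgs))
```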
```diff
@@ -249,6 +250,7 @@ def preprocess(
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         device: Optional["torch.device"] = None,
+        disable_grouping: Optional[bool] = False,
         **kwargs,
     ):
         r"""
@@ -323,6 +325,7 @@ def preprocess(
                 do_convert_rgb=do_convert_rgb,
                 input_data_format=input_data_format,
                 device=device,
+                disable_grouping=disable_grouping,
             )
             pixel_values.extend(patches)
             vision_grid_thws.append(image_grid_thw)
```
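A hedged usage sketch of the new argument, assuming the fast image processor forwards keyword arguments from `__call__` to `preprocess` as usual; the checkpoint id is a placeholder, not a specific repo:

```python
# Illustrative only: substitute a real GLM-4.1V checkpoint id for the placeholder below.
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("<glm-4.1v-checkpoint>", use_fast=True)
images = [
    Image.fromarray(np.random.randint(0, 256, (h, w, 3), dtype=np.uint8))
    for h, w in [(480, 640), (320, 320)]
]

# disable_grouping is threaded from preprocess() down to _preprocess() by this commit;
# the default (False) keeps the batched same-shape fast path.
out = processor(images=images, disable_grouping=True, return_tensors="pt")
print({k: getattr(v, "shape", v) for k, v in out.items()})
```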
```diff
@@ -351,11 +354,11 @@ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):

         factor = patch_size * merge_size
         resized_height, resized_width = smart_resize(
-            t=self.temporal_patch_size,
+            num_frames=self.temporal_patch_size,
             height=height,
             width=width,
+            temporal_factor=self.temporal_patch_size,
             factor=factor,
-            t_factor=self.temporal_patch_size,
         )
         grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
         return grid_h * grid_w
```
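For intuition about the values this method produces, here is a simplified sketch; it ignores the min/max pixel budget the real `smart_resize` enforces, and the `patch_size=14`, `merge_size=2` defaults are illustrative rather than read from the GLM-4.1V config:

```python
# Simplified sketch: snap the spatial dims to multiples of factor = patch_size * merge_size,
# then count vision patches on the resulting grid (before any 2x2 token merging).
def rough_patch_count(height: int, width: int, patch_size: int = 14, merge_size: int = 2) -> int:
    factor = patch_size * merge_size
    resized_height = max(factor, round(height / factor) * factor)
    resized_width = max(factor, round(width / factor) * factor)
    grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
    return grid_h * grid_w


print(rough_patch_count(480, 640))  # 34 * 46 = 1564 patches
```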

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -287,6 +287,7 @@ def __init__(self, config: Glm4vVisionConfig) -> None:
         self.attention_dropout = config.attention_dropout
         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.is_causal = False

     def forward(
         self,
@@ -324,7 +325,7 @@ def forward(
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scale,
-            is_causal=False,
+            is_causal=self.is_causal,
             **kwargs,
         )
         attn_output = attn_output.squeeze(0)
```
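The pattern being introduced: the module records `is_causal` as an attribute instead of hard-coding it at the call site, so attention backends and subclasses can inspect or override it. A minimal toy module (not the `Glm4vVisionAttention` implementation) showing the same idea with PyTorch SDPA:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyVisionAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.is_causal = False  # vision tokens attend bidirectionally

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        # Project to q, k, v of shape (batch, heads, seq, head_dim) each.
        q, k, v = self.qkv(x).view(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # The attribute, not a literal, decides causal masking here.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        return self.proj(out.transpose(1, 2).reshape(b, n, -1))


attn = TinyVisionAttention(hidden_size=64, num_heads=4)
print(attn(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])
```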
```diff
@@ -1016,7 +1017,7 @@ def get_rope_index(
                 dtype=input_ids.dtype,
                 device=input_ids.device,
             )
-
+            image_index, video_index = 0, 0
             attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
                 input_ids = input_ids[attention_mask[i] == 1]
@@ -1046,7 +1047,6 @@

                 llm_pos_ids_list = []
                 video_frame_num = 1
-                image_index, video_index = 0, 0

                 for modality_type, start_idx, end_idx in input_type_group:
                     st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
```
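This hunk is the batch-inference fix: `image_grid_thw` is one flat tensor over the whole batch, so the running image/video counters must be initialized once before the per-sample loop rather than reset inside it (the old placement made every sample after the first re-read the grids of sample 0). A toy illustration with made-up grids:

```python
# Made-up per-sample image counts and a flat, batch-wide list of (t, h, w) grids.
batch_image_counts = [2, 1, 3]
image_grid_thw = [(1, 4, 6), (1, 8, 8), (1, 2, 2), (1, 6, 4), (1, 4, 4), (1, 10, 10)]

image_index = 0  # fixed version: initialized once, before the batch loop
for sample_id, n_images in enumerate(batch_image_counts):
    for _ in range(n_images):
        t, h, w = image_grid_thw[image_index]  # consumes the next grid, batch-wide
        image_index += 1
    print(f"sample {sample_id}: consumed grids up to index {image_index}")
```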
```diff
@@ -1088,9 +1088,7 @@
                         t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()

                         h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
-
                         w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
-
                         llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                         video_index += 1
```
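For reference, the context lines around this hunk build the 3D (t, h, w) rope position ids for each visual token. A small worked example for one frame and a 2 × 3 token grid, using illustrative values for `t_idx`, `llm_grid_h`, `llm_grid_w`, and `st_idx`:

```python
import torch

t_idx, llm_grid_h, llm_grid_w, st_idx = [0], 2, 3, 5
t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
print(torch.stack([t_index, h_index, w_index]) + st_idx)
# tensor([[5, 5, 5, 5, 5, 5],
#         [5, 5, 5, 6, 6, 6],
#         [5, 6, 7, 5, 6, 7]])
```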

src/transformers/models/glm4v/modular_glm4v.py

Lines changed: 3 additions & 5 deletions
The modular source mirrors the changes applied to `modeling_glm4v.py` above.

```diff
@@ -516,6 +516,7 @@ def __init__(self, config: Glm4vVisionConfig) -> None:
         self.attention_dropout = config.attention_dropout
         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.is_causal = False

     def forward(
         self,
@@ -553,7 +554,7 @@ def forward(
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scale,
-            is_causal=False,
+            is_causal=self.is_causal,
             **kwargs,
         )
         attn_output = attn_output.squeeze(0)
@@ -1115,7 +1116,7 @@ def get_rope_index(
                 dtype=input_ids.dtype,
                 device=input_ids.device,
             )
-
+            image_index, video_index = 0, 0
             attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
                 input_ids = input_ids[attention_mask[i] == 1]
@@ -1145,7 +1146,6 @@

                 llm_pos_ids_list = []
                 video_frame_num = 1
-                image_index, video_index = 0, 0

                 for modality_type, start_idx, end_idx in input_type_group:
                     st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
@@ -1187,9 +1187,7 @@
                         t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()

                         h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
-
                         w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
-
                         llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                         video_index += 1
```

0 commit comments
