From adc82c85c29dc7a389135282f1b60bc1e23a371a Mon Sep 17 00:00:00 2001 From: Kingsley Date: Mon, 30 Jun 2025 17:47:41 +0000 Subject: [PATCH 1/5] changes for video --- .../models/glm4v/modeling_glm4v.py | 17 +++++++++++++---- .../models/glm4v/processing_glm4v.py | 18 ++++++++++++------ .../models/glm4v/video_processing_glm4v.py | 4 ---- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 65ec7f0b79c..608104dde13 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -185,7 +185,6 @@ class Glm4vVisionEmbeddings(nn.Module): .unsqueeze(0) .to(device=device, dtype=torch.float32) ) - # Calculate target dimensions for each patch target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( device=device, dtype=torch.float32 @@ -1018,6 +1017,7 @@ class Glm4vModel(Glm4vPreTrainedModel): device=input_ids.device, ) image_index, video_index = 0, 0 + video_group_index = 0 attention_mask = attention_mask.to(total_input_ids.device) for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] @@ -1047,7 +1047,6 @@ class Glm4vModel(Glm4vPreTrainedModel): llm_pos_ids_list = [] video_frame_num = 1 - for modality_type, start_idx, end_idx in input_type_group: st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 @@ -1091,7 +1090,11 @@ class Glm4vModel(Glm4vPreTrainedModel): w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) - video_index += 1 + video_group_index += 1 + + if video_group_index >= video_grid_thw[video_index][0]: + video_index += 1 + video_group_index = 0 video_frame_num += 1 @@ -1140,7 +1143,13 @@ class Glm4vModel(Glm4vPreTrainedModel): The temporal, height and width of feature shape of each video in LLM. 
""" pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(video_embeds, split_sizes) return video_embeds diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 5a0f5d94d81..7f44dd572dd 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -167,32 +167,38 @@ class Glm4vProcessor(ProcessorMixin): video_index = 0 for i in range(len(text)): while self.video_token in text[i]: - num_frames = len(video_grid_thw) + num_frames = video_grid_thw[video_index][0] video_structure = "" if hasattr(timestamps, "tolist"): timestamps_list = timestamps.tolist()[0] else: timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps + unique_timestamps = [] for idx in range(0, len(timestamps_list)): unique_timestamps.append(timestamps_list[idx]) + selected_timestamps = unique_timestamps[:num_frames] while len(selected_timestamps) < num_frames: selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + for frame_idx in range(num_frames): timestamp_sec = selected_timestamps[frame_idx] frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}" video_structure += frame_structure text[i] = text[i].replace(self.video_token, video_structure, 1) + + num_image_tokens = ( + video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] + ) + for frame_idx in range(num_frames): + if self.image_token in text[i]: + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + video_index += 1 - for frame_idx in range(len(video_grid_thw)): - if self.image_token in text[i]: - num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length - text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) text[i] = text[i].replace("<|placeholder|>", self.image_token) - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) diff --git a/src/transformers/models/glm4v/video_processing_glm4v.py b/src/transformers/models/glm4v/video_processing_glm4v.py index ac6a9921078..2d10f9fa74b 100644 --- a/src/transformers/models/glm4v/video_processing_glm4v.py +++ b/src/transformers/models/glm4v/video_processing_glm4v.py @@ -246,10 +246,6 @@ class Glm4vVideoProcessor(BaseVideoProcessor): processed_grids = reorder_videos(processed_grids, grouped_videos_index) pixel_values_videos = torch.cat(processed_videos, dim=0) video_grid_thw = torch.tensor(processed_grids) - total_frames = video_grid_thw[0][0].item() - h = video_grid_thw[0][1].item() - w = video_grid_thw[0][2].item() - video_grid_thw = [[1, h, w] for _ in range(total_frames)] data = { "pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw, From 
b72947176337d9afa3098ce147dc2ba1979ac331 Mon Sep 17 00:00:00 2001 From: Kingsley Date: Wed, 2 Jul 2025 10:45:41 +0000 Subject: [PATCH 2/5] update modular --- .../models/glm4v/modeling_glm4v.py | 9 ++------- src/transformers/models/glm4v/modular_glm4v.py | 8 ++++++-- .../models/glm4v/processing_glm4v.py | 18 ++++++------------ 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index eadb1849f5c..c17965421ef 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -185,6 +185,7 @@ class Glm4vVisionEmbeddings(nn.Module): .unsqueeze(0) .to(device=device, dtype=torch.float32) ) + # Calculate target dimensions for each patch target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( device=device, dtype=torch.float32 @@ -1178,13 +1179,7 @@ class Glm4vModel(Glm4vPreTrainedModel): The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - for t, h, w in video_grid_thw: - repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) - video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(video_embeds, split_sizes) return video_embeds diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index cf4a6b9233f..e331d688d36 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -1087,6 +1087,7 @@ class Glm4vModel(Qwen2_5_VLModel): device=input_ids.device, ) image_index, video_index = 0, 0 + video_group_index = 0 attention_mask = attention_mask.to(total_input_ids.device) for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] @@ -1116,7 +1117,6 @@ class Glm4vModel(Qwen2_5_VLModel): llm_pos_ids_list = [] video_frame_num = 1 - for modality_type, start_idx, end_idx in input_type_group: st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 @@ -1160,7 +1160,11 @@ class Glm4vModel(Qwen2_5_VLModel): w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) - video_index += 1 + video_group_index += 1 + + if video_group_index >= video_grid_thw[video_index][0]: + video_index += 1 + video_group_index = 0 video_frame_num += 1 diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 7f44dd572dd..5a0f5d94d81 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -167,38 +167,32 @@ class Glm4vProcessor(ProcessorMixin): video_index = 0 for i in range(len(text)): while self.video_token in text[i]: - num_frames = video_grid_thw[video_index][0] + num_frames = len(video_grid_thw) video_structure = "" if hasattr(timestamps, "tolist"): timestamps_list = timestamps.tolist()[0] else: timestamps_list = timestamps[0] if 
isinstance(timestamps[0], list) else timestamps - unique_timestamps = [] for idx in range(0, len(timestamps_list)): unique_timestamps.append(timestamps_list[idx]) - selected_timestamps = unique_timestamps[:num_frames] while len(selected_timestamps) < num_frames: selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) - for frame_idx in range(num_frames): timestamp_sec = selected_timestamps[frame_idx] frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}" video_structure += frame_structure text[i] = text[i].replace(self.video_token, video_structure, 1) - - num_image_tokens = ( - video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] - ) - for frame_idx in range(num_frames): - if self.image_token in text[i]: - text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) - video_index += 1 + for frame_idx in range(len(video_grid_thw)): + if self.image_token in text[i]: + num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) text[i] = text[i].replace("<|placeholder|>", self.image_token) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) From 5df38281d37d809dc62b13108e21d8bd3a67cc47 Mon Sep 17 00:00:00 2001 From: Kingsley Date: Wed, 2 Jul 2025 10:55:02 +0000 Subject: [PATCH 3/5] change get_video_features --- src/transformers/models/glm4v/modeling_glm4v.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index c17965421ef..5e47e06799e 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1179,7 +1179,13 @@ class Glm4vModel(Glm4vPreTrainedModel): The temporal, height and width of feature shape of each video in LLM. 
""" pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(video_embeds, split_sizes) return video_embeds From 454b4a39f449d84c7497870012d897bbf8889e7c Mon Sep 17 00:00:00 2001 From: Kingsley Date: Wed, 2 Jul 2025 11:47:06 +0000 Subject: [PATCH 4/5] update video token replacement --- .../models/glm4v/processing_glm4v.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 5a0f5d94d81..50408caa3f1 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -167,32 +167,38 @@ class Glm4vProcessor(ProcessorMixin): video_index = 0 for i in range(len(text)): while self.video_token in text[i]: - num_frames = len(video_grid_thw) + num_frames = video_grid_thw[video_index][0] video_structure = "" if hasattr(timestamps, "tolist"): timestamps_list = timestamps.tolist()[0] else: timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps + unique_timestamps = [] for idx in range(0, len(timestamps_list)): unique_timestamps.append(timestamps_list[idx]) + selected_timestamps = unique_timestamps[:num_frames] while len(selected_timestamps) < num_frames: selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + for frame_idx in range(num_frames): timestamp_sec = selected_timestamps[frame_idx] frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}" video_structure += frame_structure + text[i] = text[i].replace(self.video_token, video_structure, 1) + num_image_tokens = ( + video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] + ) + for frame_idx in range(num_frames): + if self.image_token in text[i]: + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + video_index += 1 - for frame_idx in range(len(video_grid_thw)): - if self.image_token in text[i]: - num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length - text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) text[i] = text[i].replace("<|placeholder|>", self.image_token) - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) From 2ad2ea283eeadeb244cdeb6769a1edffa1b8df37 Mon Sep 17 00:00:00 2001 From: Kingsley Date: Wed, 2 Jul 2025 17:20:56 +0000 Subject: [PATCH 5/5] update modular --- .../models/glm4v/modular_glm4v.py | 42 ++++++++++++++++--- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index e331d688d36..cac82a30f75 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py 
@@ -1200,6 +1200,30 @@ class Glm4vModel(Qwen2_5_VLModel): return position_ids, mrope_position_deltas + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + @auto_docstring @can_return_tuple def forward( @@ -1691,32 +1715,38 @@ class Glm4vProcessor(Qwen2_5_VLProcessor): video_index = 0 for i in range(len(text)): while self.video_token in text[i]: - num_frames = len(video_grid_thw) + num_frames = video_grid_thw[video_index][0] video_structure = "" if hasattr(timestamps, "tolist"): timestamps_list = timestamps.tolist()[0] else: timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps + unique_timestamps = [] for idx in range(0, len(timestamps_list)): unique_timestamps.append(timestamps_list[idx]) + selected_timestamps = unique_timestamps[:num_frames] while len(selected_timestamps) < num_frames: selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + for frame_idx in range(num_frames): timestamp_sec = selected_timestamps[frame_idx] frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}" video_structure += frame_structure + text[i] = text[i].replace(self.video_token, video_structure, 1) + num_image_tokens = ( + video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] + ) + for frame_idx in range(num_frames): + if self.image_token in text[i]: + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + video_index += 1 - for frame_idx in range(len(video_grid_thw)): - if self.image_token in text[i]: - num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length - text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) text[i] = text[i].replace("<|placeholder|>", self.image_token) - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
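
Taken together, these patches make get_video_features flatten each video's [t, h, w] grid into t single-frame [1, h, w] rows before calling the vision tower, while still splitting the resulting embeddings per video. Below is a minimal standalone sketch of that reshaping, using only torch, illustrative grid values, and an assumed spatial_merge_size of 2:

import torch

# Two illustrative videos: 4 and 2 frames, each with a 6x8 patch grid.
video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])

# Each [t, h, w] row becomes t rows of [1, h, w], so the vision tower treats
# every frame as an independent single-frame image grid.
flattened_video_grid_thw = torch.cat(
    [torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) for t, h, w in video_grid_thw.tolist()],
    dim=0,
)
print(flattened_video_grid_thw.shape)  # torch.Size([6, 3]) -> one row per frame

# The split back into per-video chunks still uses the original grids, so each
# video keeps (t * h * w) / spatial_merge_size**2 merged embeddings.
spatial_merge_size = 2  # assumed value for illustration
split_sizes = (video_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
print(split_sizes)  # [48, 24]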
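
The processor side mirrors that per-frame view: one video placeholder expands into num_frames blocks of <|begin_of_image|>{image token}<|end_of_image|>{timestamp}, and each frame's image token is then replaced by prod(t, h, w) // merge_length // t placeholder tokens. The sketch below reproduces only that accounting; the literal <|video|> and <|image|> strings stand in for the processor's video_token and image_token attributes, and merge_length, the grid, and the timestamps are made-up values:

import torch

video_grid_thw = torch.tensor([[4, 6, 8]])  # one video: 4 frames, 6x8 patch grid
merge_length = 4                            # spatial_merge_size ** 2 (assumed)
timestamps = [0.0, 0.5, 1.0, 1.5]

num_frames = int(video_grid_thw[0][0])
# Tokens per frame: (t * h * w) // merge_length // t == (h * w) // merge_length.
num_image_tokens = int(video_grid_thw[0].prod() // merge_length // video_grid_thw[0][0])

text = "Describe this video: <|video|>"
frame = "<|begin_of_image|><|image|><|end_of_image|>{ts}"
text = text.replace(
    "<|video|>",
    "".join(frame.format(ts=timestamps[min(i, len(timestamps) - 1)]) for i in range(num_frames)),
    1,
)
# Expand each frame's image token, then hand the placeholders back as image tokens.
for _ in range(num_frames):
    text = text.replace("<|image|>", "<|placeholder|>" * num_image_tokens, 1)
text = text.replace("<|placeholder|>", "<|image|>")
print(text.count("<|image|>"))  # 48 == 4 frames * 12 tokens per frame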
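
Finally, the rope-index change in patches 1 and 2: because each frame now contributes its own position group, the video counter may only advance once all t frame groups of the current video have been consumed, which is what the new video_group_index tracks. A schematic of just that bookkeeping, with illustrative grids and the per-frame position computation elided:

# Per-video [t, h, w] grids; the frame groups below stand in for the "video"
# segments of input_type_group in get_rope_index.
video_grid_thw = [[4, 6, 8], [2, 6, 8]]
frame_groups = ["video"] * sum(t for t, _, _ in video_grid_thw)  # 6 per-frame groups

video_index, video_group_index = 0, 0
for _ in frame_groups:
    # ... temporal/height/width position ids for one frame would be appended here ...
    video_group_index += 1
    if video_group_index >= video_grid_thw[video_index][0]:
        video_index += 1       # all frames of the current video consumed
        video_group_index = 0
print(video_index)  # 2 -> both videos fully accounted for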