Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Commit 7e0280250f: Merge 2ad2ea283e into 37a239ca50
@@ -1053,6 +1053,7 @@ class Glm4vModel(Glm4vPreTrainedModel):
        device=input_ids.device,
    )
    image_index, video_index = 0, 0
    video_group_index = 0
    attention_mask = attention_mask.to(total_input_ids.device)
    for i, input_ids in enumerate(total_input_ids):
        input_ids = input_ids[attention_mask[i] == 1]
@@ -1082,7 +1083,6 @@ class Glm4vModel(Glm4vPreTrainedModel):

    llm_pos_ids_list = []
    video_frame_num = 1

    for modality_type, start_idx, end_idx in input_type_group:
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
@@ -1126,7 +1126,11 @@ class Glm4vModel(Glm4vPreTrainedModel):
    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

    video_index += 1
    video_group_index += 1

    if video_group_index >= video_grid_thw[video_index][0]:
        video_index += 1
        video_group_index = 0

    video_frame_num += 1
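A minimal, self-contained sketch of the video_index / video_group_index bookkeeping introduced in the hunk above. The grid values (two videos with 3 and 2 temporal groups) are made up for illustration; it only shows how the counters advance:

    # Illustrative values, not taken from the diff: two videos with 3 and 2 temporal groups.
    video_grid_thw = [[3, 6, 8], [2, 6, 8]]
    video_index, video_group_index = 0, 0
    visits = []
    for _ in range(5):  # one iteration per video frame group seen in the token sequence
        visits.append(video_index)
        video_group_index += 1
        if video_group_index >= video_grid_thw[video_index][0]:
            video_index += 1
            video_group_index = 0
    print(visits)  # [0, 0, 0, 1, 1] -> groups 1-3 belong to video 0, groups 4-5 to video 1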
@@ -1175,7 +1179,13 @@ class Glm4vModel(Glm4vPreTrainedModel):
            The temporal, height and width of feature shape of each video in LLM.
        """
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
        temp_frames_hw = []
        for t, h, w in video_grid_thw:
            repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
            temp_frames_hw.append(repeated_row)
        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
        video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return video_embeds
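As a rough illustration of the grid flattening added in get_video_features above (the grid values here are assumptions), each per-video [t, h, w] row is expanded into t per-frame [1, h, w] rows before the visual encoder is called:

    import torch

    # Hypothetical grids: two videos with 4 and 2 temporal patches, each 6x8 spatial patches.
    video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])

    temp_frames_hw = []
    for t, h, w in video_grid_thw:
        # one [1, h, w] row per frame, mirroring the loop in the hunk above
        temp_frames_hw.append(torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1))
    flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)

    print(flattened_video_grid_thw.shape)  # torch.Size([6, 3]) -> 4 + 2 per-frame rows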
@@ -1087,6 +1087,7 @@ class Glm4vModel(Qwen2_5_VLModel):
        device=input_ids.device,
    )
    image_index, video_index = 0, 0
    video_group_index = 0
    attention_mask = attention_mask.to(total_input_ids.device)
    for i, input_ids in enumerate(total_input_ids):
        input_ids = input_ids[attention_mask[i] == 1]
@@ -1116,7 +1117,6 @@ class Glm4vModel(Qwen2_5_VLModel):

    llm_pos_ids_list = []
    video_frame_num = 1

    for modality_type, start_idx, end_idx in input_type_group:
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
@@ -1160,7 +1160,11 @@ class Glm4vModel(Qwen2_5_VLModel):
    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

    video_index += 1
    video_group_index += 1

    if video_group_index >= video_grid_thw[video_index][0]:
        video_index += 1
        video_group_index = 0

    video_frame_num += 1
@@ -1196,6 +1200,30 @@ class Glm4vModel(Qwen2_5_VLModel):

        return position_ids, mrope_position_deltas

    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """
        Encodes videos into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input videos.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
        """
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
        temp_frames_hw = []
        for t, h, w in video_grid_thw:
            repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
            temp_frames_hw.append(repeated_row)
        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
        video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return video_embeds

    @auto_docstring
    @can_return_tuple
    def forward(
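For the split step at the end of get_video_features, here is a small worked example of how split_sizes is derived; the grid values and a spatial merge size of 2 are assumptions for illustration only:

    import torch

    video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])  # assumed grids
    spatial_merge_size = 2                                  # assumed merge size

    # tokens per video after spatial merging: t * h * w / merge_size**2
    split_sizes = (video_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
    print(split_sizes)  # [48, 24]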
@@ -1687,32 +1715,38 @@ class Glm4vProcessor(Qwen2_5_VLProcessor):
    video_index = 0
    for i in range(len(text)):
        while self.video_token in text[i]:
            num_frames = len(video_grid_thw)
            num_frames = video_grid_thw[video_index][0]
            video_structure = ""

            if hasattr(timestamps, "tolist"):
                timestamps_list = timestamps.tolist()[0]
            else:
                timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

            unique_timestamps = []
            for idx in range(0, len(timestamps_list)):
                unique_timestamps.append(timestamps_list[idx])

            selected_timestamps = unique_timestamps[:num_frames]
            while len(selected_timestamps) < num_frames:
                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

            for frame_idx in range(num_frames):
                timestamp_sec = selected_timestamps[frame_idx]
                frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
                video_structure += frame_structure

            text[i] = text[i].replace(self.video_token, video_structure, 1)
            num_image_tokens = (
                video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
            )
            for frame_idx in range(num_frames):
                if self.image_token in text[i]:
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

            video_index += 1

        for frame_idx in range(len(video_grid_thw)):
            if self.image_token in text[i]:
                num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
                text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
        text[i] = text[i].replace("<|placeholder|>", self.image_token)

    return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
    text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
    self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
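A toy rendition of the token expansion performed by the processor above; the token strings, timestamps, and per-frame token count are hypothetical stand-ins, and only the sequence of string replacements is the point:

    image_token = "<|image|>"   # hypothetical stand-in for self.image_token
    video_token = "<|video|>"   # hypothetical stand-in for self.video_token
    timestamps = [0, 2, 4]      # assumed frame timestamps in seconds
    num_frames = 3
    num_image_tokens = 4        # assumed tokens per frame after spatial merging

    text = f"Describe this clip: {video_token}"
    video_structure = "".join(
        f"<|begin_of_image|>{image_token}<|end_of_image|>{timestamps[f]}" for f in range(num_frames)
    )
    text = text.replace(video_token, video_structure, 1)
    for _ in range(num_frames):
        text = text.replace(image_token, "<|placeholder|>" * num_image_tokens, 1)
    text = text.replace("<|placeholder|>", image_token)
    print(text.count(image_token))  # 12 -> 3 frames x 4 tokens each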
@@ -167,32 +167,38 @@ class Glm4vProcessor(ProcessorMixin):
    video_index = 0
    for i in range(len(text)):
        while self.video_token in text[i]:
            num_frames = len(video_grid_thw)
            num_frames = video_grid_thw[video_index][0]
            video_structure = ""

            if hasattr(timestamps, "tolist"):
                timestamps_list = timestamps.tolist()[0]
            else:
                timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

            unique_timestamps = []
            for idx in range(0, len(timestamps_list)):
                unique_timestamps.append(timestamps_list[idx])

            selected_timestamps = unique_timestamps[:num_frames]
            while len(selected_timestamps) < num_frames:
                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

            for frame_idx in range(num_frames):
                timestamp_sec = selected_timestamps[frame_idx]
                frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
                video_structure += frame_structure

            text[i] = text[i].replace(self.video_token, video_structure, 1)
            num_image_tokens = (
                video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
            )
            for frame_idx in range(num_frames):
                if self.image_token in text[i]:
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

            video_index += 1

        for frame_idx in range(len(video_grid_thw)):
            if self.image_token in text[i]:
                num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
                text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
        text[i] = text[i].replace("<|placeholder|>", self.image_token)

    return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
    text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
    self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
@@ -246,10 +246,6 @@ class Glm4vVideoProcessor(BaseVideoProcessor):
    processed_grids = reorder_videos(processed_grids, grouped_videos_index)
    pixel_values_videos = torch.cat(processed_videos, dim=0)
    video_grid_thw = torch.tensor(processed_grids)
    total_frames = video_grid_thw[0][0].item()
    h = video_grid_thw[0][1].item()
    w = video_grid_thw[0][2].item()
    video_grid_thw = [[1, h, w] for _ in range(total_frames)]
    data = {
        "pixel_values_videos": pixel_values_videos,
        "video_grid_thw": video_grid_thw,
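For reference, a small sketch of what the removed lines in the video-processor hunk used to compute, on an illustrative grid tensor; based on the lines shown, the new code instead returns the video_grid_thw tensor built from processed_grids as-is:

    import torch

    video_grid_thw = torch.tensor([[4, 6, 8]])  # assumed single-video grid
    total_frames = video_grid_thw[0][0].item()
    h = video_grid_thw[0][1].item()
    w = video_grid_thw[0][2].item()
    per_frame_grid = [[1, h, w] for _ in range(total_frames)]
    print(per_frame_grid)  # [[1, 6, 8], [1, 6, 8], [1, 6, 8], [1, 6, 8]]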