update video token replacement

This commit is contained in:
Kingsley 2025-07-02 11:47:06 +00:00
parent 5df38281d3
commit 454b4a39f4

View File

@ -167,32 +167,38 @@ class Glm4vProcessor(ProcessorMixin):
video_index = 0 video_index = 0
for i in range(len(text)): for i in range(len(text)):
while self.video_token in text[i]: while self.video_token in text[i]:
num_frames = len(video_grid_thw) num_frames = video_grid_thw[video_index][0]
video_structure = "" video_structure = ""
if hasattr(timestamps, "tolist"): if hasattr(timestamps, "tolist"):
timestamps_list = timestamps.tolist()[0] timestamps_list = timestamps.tolist()[0]
else: else:
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
unique_timestamps = [] unique_timestamps = []
for idx in range(0, len(timestamps_list)): for idx in range(0, len(timestamps_list)):
unique_timestamps.append(timestamps_list[idx]) unique_timestamps.append(timestamps_list[idx])
selected_timestamps = unique_timestamps[:num_frames] selected_timestamps = unique_timestamps[:num_frames]
while len(selected_timestamps) < num_frames: while len(selected_timestamps) < num_frames:
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
for frame_idx in range(num_frames): for frame_idx in range(num_frames):
timestamp_sec = selected_timestamps[frame_idx] timestamp_sec = selected_timestamps[frame_idx]
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}" frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
video_structure += frame_structure video_structure += frame_structure
text[i] = text[i].replace(self.video_token, video_structure, 1) text[i] = text[i].replace(self.video_token, video_structure, 1)
num_image_tokens = (
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)
for frame_idx in range(num_frames):
if self.image_token in text[i]:
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
video_index += 1 video_index += 1
for frame_idx in range(len(video_grid_thw)):
if self.image_token in text[i]:
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
text[i] = text[i].replace("<|placeholder|>", self.image_token) text[i] = text[i].replace("<|placeholder|>", self.image_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])