mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Chat template: update for processor (#35953)
* update * we need batched nested input to always process correctly * update a bit * fix copies
This commit is contained in:
parent
5bd7694781
commit
eebd2c972c
@ -562,7 +562,7 @@ def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] =
|
||||
return indices
|
||||
|
||||
|
||||
def read_video_opencv(video_path: str, num_frames: Optional[int] = None):
|
||||
def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
"""
|
||||
Decode the video with open-cv decoder.
|
||||
|
||||
@ -570,13 +570,25 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None):
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
"""
|
||||
video = cv2.VideoCapture(video_path)
|
||||
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
video_fps = video.get(cv2.CAP_PROP_FPS)
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
)
|
||||
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
|
||||
|
||||
index = 0
|
||||
@ -595,7 +607,7 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None):
|
||||
return np.stack(frames)
|
||||
|
||||
|
||||
def read_video_decord(video_path: str, num_frames: Optional[int] = None):
|
||||
def read_video_decord(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
"""
|
||||
Decode the video with Decord decoder.
|
||||
|
||||
@ -603,18 +615,31 @@ def read_video_decord(video_path: str, num_frames: Optional[int] = None):
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
"""
|
||||
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
||||
indices = get_uniform_frame_indices(total_num_frames=len(vr), num_frames=num_frames)
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
)
|
||||
indices = get_uniform_frame_indices(total_num_frames=total_num_frames, num_frames=num_frames)
|
||||
frames = vr.get_batch(indices).asnumpy()
|
||||
return frames
|
||||
|
||||
|
||||
def read_video_pyav(video_path: str, num_frames: Optional[int] = None):
|
||||
def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
"""
|
||||
Decode the video with PyAV decoder.
|
||||
|
||||
@ -622,15 +647,26 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None):
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
"""
|
||||
container = av.open(video_path)
|
||||
|
||||
# sample uniformly "num_frames" frames from the video
|
||||
total_num_frames = container.streams.video[0].frames
|
||||
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
)
|
||||
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
|
||||
|
||||
frames = []
|
||||
@ -644,7 +680,7 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None):
|
||||
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
|
||||
|
||||
def read_video_torchvision(video_path: str, num_frames: Optional[int] = None):
|
||||
def read_video_torchvision(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
"""
|
||||
Decode the video with torchvision decoder.
|
||||
|
||||
@ -652,7 +688,11 @@ def read_video_torchvision(video_path: str, num_frames: Optional[int] = None):
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
@ -664,6 +704,15 @@ def read_video_torchvision(video_path: str, num_frames: Optional[int] = None):
|
||||
pts_unit="sec",
|
||||
output_format="TCHW",
|
||||
)
|
||||
video_fps = info["video_fps"]
|
||||
total_num_frames = video.size(0) - 1
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
)
|
||||
|
||||
if num_frames is not None:
|
||||
idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
|
||||
@ -680,7 +729,12 @@ VIDEO_DECODERS = {
|
||||
}
|
||||
|
||||
|
||||
def load_video(video: Union[str, "VideoInput"], num_frames: Optional[int] = None, backend: str = "opencv") -> np.array:
|
||||
def load_video(
|
||||
video: Union[str, "VideoInput"],
|
||||
num_frames: Optional[int] = None,
|
||||
fps: Optional[int] = None,
|
||||
backend: str = "opencv",
|
||||
) -> np.array:
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
@ -689,12 +743,19 @@ def load_video(video: Union[str, "VideoInput"], num_frames: Optional[int] = None
|
||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
backend (`str`, *optional*, defaults to `"opencv"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
|
||||
|
||||
Returns:
|
||||
`np.array`: A numpy array of shape (num_frames, channels, height, width).
|
||||
"""
|
||||
|
||||
if fps is not None and num_frames is not None:
|
||||
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
|
||||
|
||||
if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
|
||||
if not is_yt_dlp_available():
|
||||
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
||||
@ -735,7 +796,7 @@ def load_video(video: Union[str, "VideoInput"], num_frames: Optional[int] = None
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video = video_decoder(file_obj)
|
||||
video = video_decoder(file_obj, num_frames=num_frames, fps=fps)
|
||||
return video
|
||||
|
||||
|
||||
|
@ -110,6 +110,8 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
The resampling filter to use if resizing the image.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_mean: List[float] = None,
|
||||
|
@ -499,6 +499,8 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
The resampling filter to use if resizing the image.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_mean: List[float] = None,
|
||||
@ -997,6 +999,10 @@ class AriaProcessor(ProcessorMixin):
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
|
||||
# Remove `num_crops`, it is popped and used only when processing. Make a copy of list when remocing
|
||||
# otherwise `self.image_processor.model_input_names` is also modified
|
||||
image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
|
||||
|
@ -158,6 +158,10 @@ class AriaProcessor(ProcessorMixin):
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
|
||||
# Remove `num_crops`, it is popped and used only when processing. Make a copy of list when remocing
|
||||
# otherwise `self.image_processor.model_input_names` is also modified
|
||||
image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
|
||||
|
@ -132,7 +132,7 @@ class Emu3ImageProcessor(BaseImageProcessor):
|
||||
The spatial downsample factor the image will be downsampled in feature extracting phase
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
model_input_names = ["pixel_values", "image_sizes"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -63,6 +63,7 @@ class Emu3Processor(ProcessorMixin):
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
|
||||
image_processor_class = "Emu3ImageProcessor"
|
||||
|
||||
@ -179,7 +180,7 @@ class Emu3Processor(ProcessorMixin):
|
||||
data = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
data.update(**image_features)
|
||||
|
||||
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
|
||||
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].pop("return_tensors", None))
|
||||
|
||||
def calculate_generate_size(self, ratio, image_area, spatial_factor):
|
||||
width, height = map(int, ratio.split(":"))
|
||||
|
@ -184,7 +184,7 @@ class Idefics2ImageProcessor(BaseImageProcessor):
|
||||
strategy was first introduced in https://arxiv.org/abs/2311.06607.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
model_input_names = ["pixel_values", "pixel_attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -289,7 +289,7 @@ class Idefics3ImageProcessor(BaseImageProcessor):
|
||||
sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
model_input_names = ["pixel_values", "pixel_attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -129,6 +129,7 @@ class Idefics3Processor(ProcessorMixin):
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["image_seq_len", "chat_template"]
|
||||
image_processor_class = "Idefics3ImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
@ -354,7 +355,7 @@ class Idefics3Processor(ProcessorMixin):
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
|
||||
|
||||
|
||||
__all__ = ["Idefics3Processor"]
|
||||
|
@ -163,7 +163,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
model_input_names = ["pixel_values", "image_sizes"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -202,14 +202,17 @@ class MllamaProcessor(ProcessorMixin):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
image_processor_class = "MllamaImageProcessor"
|
||||
tokenizer_class = "PreTrainedTokenizerFast"
|
||||
|
||||
def __init__(self, image_processor, tokenizer):
|
||||
def __init__(self, image_processor, tokenizer, chat_template=None):
|
||||
if not hasattr(tokenizer, "image_token"):
|
||||
self.image_token = "<|image|>"
|
||||
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
@ -220,8 +223,7 @@ class MllamaProcessor(ProcessorMixin):
|
||||
self.python_token = "<|python_tag|>"
|
||||
self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
|
||||
self.bos_token = tokenizer.bos_token
|
||||
self.chat_template = tokenizer.chat_template
|
||||
super().__init__(image_processor, tokenizer)
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -364,6 +366,10 @@ class MllamaProcessor(ProcessorMixin):
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
|
||||
# Remove `num_tiles`, it is popped and used only when processing. Make a copy of list when remocing
|
||||
# otherwise `self.image_processor.model_input_names` is also modified
|
||||
image_processor_input_names = [name for name in image_processor_input_names if name != "num_tiles"]
|
||||
return list(tokenizer_input_names + image_processor_input_names + ["cross_attention_mask"])
|
||||
|
||||
|
||||
|
@ -379,6 +379,9 @@ class ChatTemplateKwargs(TypedDict, total=False):
|
||||
The backend to use when loading the video which will be used only when there are videos in the conversation.
|
||||
Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend
|
||||
that supports all types of sources to load from.
|
||||
video_fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
"""
|
||||
|
||||
tokenize: Optional[bool] = False
|
||||
@ -390,6 +393,7 @@ class ChatTemplateKwargs(TypedDict, total=False):
|
||||
return_assistant_tokens_mask: Optional[bool] = False
|
||||
num_frames: Optional[int] = None
|
||||
video_load_backend: Optional[str] = "pyav"
|
||||
video_fps: Optional[int] = None
|
||||
|
||||
|
||||
class AllKwargsForChatTemplate(
|
||||
@ -762,7 +766,11 @@ class ProcessorMixin(PushToHubMixin):
|
||||
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
|
||||
# However, for models added in the future, we won't get the expected error if this file is missing.
|
||||
if resolved_processor_file is None:
|
||||
return {}, kwargs
|
||||
# In any case we need to pass `chat_template` if it is available
|
||||
processor_dict = {}
|
||||
if "chat_template" in kwargs:
|
||||
processor_dict = {"chat_template": kwargs.pop("chat_template")}
|
||||
return processor_dict, kwargs
|
||||
|
||||
try:
|
||||
# Load processor dict
|
||||
@ -786,6 +794,9 @@ class ProcessorMixin(PushToHubMixin):
|
||||
"in the processor's config. Make sure to move your template to its own file."
|
||||
)
|
||||
|
||||
if "chat_template" in kwargs:
|
||||
processor_dict["chat_template"] = kwargs.pop("chat_template")
|
||||
|
||||
if not is_local:
|
||||
if "auto_map" in processor_dict:
|
||||
processor_dict["auto_map"] = add_model_info_to_auto_map(
|
||||
@ -817,7 +828,6 @@ class ProcessorMixin(PushToHubMixin):
|
||||
"""
|
||||
processor_dict = processor_dict.copy()
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
chat_template = kwargs.pop("chat_template", None)
|
||||
|
||||
# We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
|
||||
# If we don't pop, some specific kwargs will raise a warning
|
||||
@ -829,8 +839,6 @@ class ProcessorMixin(PushToHubMixin):
|
||||
|
||||
unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
|
||||
processor = cls(*args, **processor_dict)
|
||||
if chat_template is not None:
|
||||
setattr(processor, "chat_template", chat_template)
|
||||
|
||||
# Update processor with kwargs if needed
|
||||
for key in set(kwargs.keys()):
|
||||
@ -1199,12 +1207,6 @@ class ProcessorMixin(PushToHubMixin):
|
||||
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
|
||||
)
|
||||
|
||||
text_kwargs = {}
|
||||
for key in TextKwargs.__annotations__.keys():
|
||||
value = kwargs.pop(key, None)
|
||||
if value is not None:
|
||||
text_kwargs[key] = value
|
||||
|
||||
chat_template_kwargs = {}
|
||||
for key in ChatTemplateKwargs.__annotations__.keys():
|
||||
value = kwargs.pop(key, getattr(ChatTemplateKwargs, key))
|
||||
@ -1214,6 +1216,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
tokenize = chat_template_kwargs.pop("tokenize")
|
||||
return_dict = chat_template_kwargs.pop("return_dict")
|
||||
num_frames = chat_template_kwargs.pop("num_frames")
|
||||
video_fps = chat_template_kwargs.pop("video_fps")
|
||||
video_load_backend = chat_template_kwargs.pop("video_load_backend")
|
||||
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
@ -1221,31 +1224,68 @@ class ProcessorMixin(PushToHubMixin):
|
||||
chat_template=chat_template,
|
||||
tokenize=False,
|
||||
return_dict=False,
|
||||
**text_kwargs,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
# we will have to return all processed inputs in a dict
|
||||
if isinstance(conversation, (list, tuple)) and (
|
||||
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
|
||||
):
|
||||
conversations = conversation
|
||||
is_batched = True
|
||||
else:
|
||||
conversations = [conversation]
|
||||
is_batched = False
|
||||
|
||||
if tokenize:
|
||||
images, videos = [], []
|
||||
for message in conversation:
|
||||
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
|
||||
for vision_info in visuals:
|
||||
if vision_info["type"] == "image":
|
||||
for key in ["image", "url", "path", "base64"]:
|
||||
if key in vision_info:
|
||||
images.append(load_image(vision_info[key]))
|
||||
elif vision_info["type"] == "video":
|
||||
for key in ["video", "url", "path"]:
|
||||
if key in vision_info:
|
||||
videos.append(
|
||||
load_video(vision_info[key], num_frames=num_frames, backend=video_load_backend)
|
||||
)
|
||||
batch_images, batch_videos = [], []
|
||||
for conversation in conversations:
|
||||
images, videos = [], []
|
||||
for message in conversation:
|
||||
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
|
||||
image_fnames = [
|
||||
vision_info[key]
|
||||
for vision_info in visuals
|
||||
for key in ["image", "url", "path", "base64"]
|
||||
if key in vision_info and vision_info["type"] == "image"
|
||||
]
|
||||
video_fnames = [
|
||||
vision_info[key]
|
||||
for vision_info in visuals
|
||||
for key in ["video", "url", "path"]
|
||||
if key in vision_info and vision_info["type"] == "video"
|
||||
]
|
||||
for fname in image_fnames:
|
||||
images.append(load_image(fname))
|
||||
for fname in video_fnames:
|
||||
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
|
||||
video = [np.array(load_image(image_fname)).T for image_fname in fname]
|
||||
# create a 4D video because `load_video` always returns a 4D array
|
||||
video = np.stack(video)
|
||||
else:
|
||||
video = load_video(fname, num_frames=num_frames, fps=video_fps, backend=video_load_backend)
|
||||
videos.append(video)
|
||||
|
||||
# Currently all processors can accept accept nested list of batches, but not flat list of visuals
|
||||
# So we'll make a batched list of images and let the processor handle it
|
||||
if images:
|
||||
batch_images.append(images)
|
||||
if videos:
|
||||
batch_videos.append(videos)
|
||||
|
||||
# Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
|
||||
# But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
|
||||
# and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling
|
||||
# everything internally. The below line is to keep BC for that and be able to work with model that have
|
||||
# special tokens in the template (consistent with tokenizers). We dont want to raise warning, it will flood command line
|
||||
# without actionable solution for users
|
||||
single_prompt = prompt[0] if is_batched else prompt
|
||||
if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
|
||||
kwargs["add_special_tokens"] = False
|
||||
|
||||
out = self(
|
||||
text=prompt,
|
||||
images=images if images else None,
|
||||
videos=videos if videos else None,
|
||||
images=batch_images if batch_images else None,
|
||||
videos=batch_videos if batch_videos else None,
|
||||
**kwargs,
|
||||
)
|
||||
if return_dict:
|
||||
|
@ -237,6 +237,55 @@ And who is that?<|im_end|>
|
||||
"""
|
||||
self.assertEqual(rendered, expected_rendered)
|
||||
|
||||
# Override as AriaImageProcessor doesn't accept `do_rescale`
|
||||
def test_chat_template_accepts_processing_kwargs(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
padding="max_length",
|
||||
max_length=50,
|
||||
)
|
||||
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
truncation=True,
|
||||
max_length=5,
|
||||
)
|
||||
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
|
||||
|
||||
# Now test the ability to return dict
|
||||
messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
max_image_size=980,
|
||||
return_tensors="np",
|
||||
)
|
||||
self.assertListEqual(list(out_dict[self.images_input_name].shape), [1, 3, 980, 980])
|
||||
|
||||
# Override as AriaProcessor needs image tokens in prompts
|
||||
def prepare_text_inputs(self, batch_size: Optional[int] = None):
|
||||
if batch_size is None:
|
||||
|
@ -52,6 +52,11 @@ class Emu3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {
|
||||
"chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
|
||||
} # fmt: skip
|
||||
|
||||
def test_processor_for_generation(self):
|
||||
processor_components = self.prepare_components()
|
||||
processor = self.processor_class(**processor_components)
|
||||
|
@ -17,7 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@ -27,7 +27,7 @@ if is_vision_available():
|
||||
from transformers import CLIPImageProcessor
|
||||
|
||||
if is_torch_available:
|
||||
import torch
|
||||
pass
|
||||
|
||||
|
||||
@require_vision
|
||||
@ -53,7 +53,11 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"}
|
||||
return {
|
||||
"chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
|
||||
"patch_size": 3,
|
||||
"vision_feature_select_strategy": "default"
|
||||
} # fmt: skip
|
||||
|
||||
@unittest.skip(
|
||||
"Skip because the model has no processor kwargs except for chat template and"
|
||||
@ -123,29 +127,6 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])
|
||||
|
||||
@require_torch
|
||||
def test_chat_template_dict_torch(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
out_dict_tensors = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values"])
|
||||
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
|
||||
|
||||
def test_chat_template_with_continue_final_message(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
|
||||
|
@ -50,7 +50,11 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"}
|
||||
return {
|
||||
"chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
|
||||
"patch_size": 3,
|
||||
"vision_feature_select_strategy": "default"
|
||||
} # fmt: skip
|
||||
|
||||
@unittest.skip(
|
||||
"Skip because the model has no processor kwargs except for chat template and"
|
||||
|
@ -16,7 +16,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||
from transformers.testing_utils import require_av, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@ -32,7 +32,7 @@ if is_vision_available():
|
||||
)
|
||||
|
||||
if is_torch_available:
|
||||
import torch
|
||||
pass
|
||||
|
||||
|
||||
@require_vision
|
||||
@ -61,7 +61,11 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {"chat_template": "dummy_template", "num_image_tokens": 6, "vision_feature_select_strategy": "default"}
|
||||
return {
|
||||
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ '\n' + content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"num_image_tokens": 6,
|
||||
"vision_feature_select_strategy": "default"
|
||||
} # fmt: skip
|
||||
|
||||
def test_processor_to_json_string(self):
|
||||
processor = self.get_processor()
|
||||
@ -133,30 +137,3 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True
|
||||
)
|
||||
self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
|
||||
|
||||
@require_torch
|
||||
@require_av
|
||||
def test_chat_template_dict_torch(self):
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
out_dict_tensors = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
|
||||
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -52,6 +53,20 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {"chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} # fmt: skip
|
||||
|
||||
def test_chat_template_is_saved(self):
|
||||
processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
processor_dict_loaded = json.loads(processor_loaded.to_json_string())
|
||||
# chat templates aren't serialized to json in processors
|
||||
self.assertFalse("chat_template" in processor_dict_loaded.keys())
|
||||
|
||||
# they have to be saved as separate file and loaded back from that file
|
||||
# so we check if the same template is loaded
|
||||
processor_dict = self.prepare_processor_dict()
|
||||
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
|
||||
|
||||
def test_apply_chat_template(self):
|
||||
# Message contains content which a mix of lists with images and image urls and string
|
||||
messages = [
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -19,7 +20,7 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from transformers import AutoProcessor, Qwen2Tokenizer
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@ -45,6 +46,9 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"} # fmt: skip
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
@ -111,3 +115,198 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
inputs = processor(text=input_str, images=image_input, videos=video_inputs)
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
def test_chat_template_single(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Now test the ability to return dict
|
||||
messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertTrue(self.images_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 1)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
|
||||
|
||||
def test_chat_template_batched(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
batched_messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 2)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, padding=True
|
||||
)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None, padding=True).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Now test the ability to return dict
|
||||
batched_messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
batched_messages[1][0]["content"].append(
|
||||
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
self.assertTrue(self.images_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 2)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 2)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 90480)
|
||||
|
||||
@require_av
|
||||
def test_chat_template_video(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Add video URL for return dict and load with `num_frames` arg
|
||||
messages[0][0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
}
|
||||
num_frames = 3
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 115200)
|
||||
|
||||
# Load with `video_fps` arg
|
||||
video_fps = 1
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 288000)
|
||||
|
||||
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# Load without any arg should load the whole video
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8640000)
|
||||
|
||||
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||
# because we assume they come from one video
|
||||
messages[0][0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": [
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
],
|
||||
}
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 71280)
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -19,7 +20,7 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from transformers import AutoProcessor, Qwen2Tokenizer
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@ -45,6 +46,9 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"} # fmt: skip
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
@ -108,3 +112,198 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
inputs = processor(text=input_str, images=image_input, videos=video_inputs)
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
def test_chat_template_single(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Now test the ability to return dict
|
||||
messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertTrue(self.images_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 1)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
|
||||
|
||||
def test_chat_template_batched(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
batched_messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 2)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, padding=True
|
||||
)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None, padding=True).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Now test the ability to return dict
|
||||
batched_messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
batched_messages[1][0]["content"].append(
|
||||
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
self.assertTrue(self.images_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 2)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 2)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 90480)
|
||||
|
||||
@require_av
|
||||
def test_chat_template_video(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Add video URL for return dict and load with `num_frames` arg
|
||||
messages[0][0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
}
|
||||
num_frames = 3
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 115200)
|
||||
|
||||
# Load with `video_fps` arg
|
||||
video_fps = 1
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 288000)
|
||||
|
||||
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# Load without any arg should load the whole video
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8640000)
|
||||
|
||||
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||
# because we assume they come from one video
|
||||
messages[0][0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": [
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
],
|
||||
}
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 71280)
|
||||
|
@ -27,10 +27,11 @@ from transformers.models.auto.processing_auto import processor_class_from_name
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.testing_utils import (
|
||||
check_json_file_has_correct_format,
|
||||
require_av,
|
||||
require_torch,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
@ -38,6 +39,9 @@ global_rng = random.Random()
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def prepare_image_inputs():
|
||||
"""This function prepares a list of PIL images"""
|
||||
@ -131,8 +135,10 @@ class ProcessorTesterMixin:
|
||||
processor = self.get_processor()
|
||||
obj = json.loads(processor.to_json_string())
|
||||
for key, value in self.prepare_processor_dict().items():
|
||||
self.assertEqual(obj[key], value)
|
||||
self.assertEqual(getattr(processor, key, None), value)
|
||||
# Chat template is saved as a separate file
|
||||
if key not in "chat_template":
|
||||
self.assertEqual(obj[key], value)
|
||||
self.assertEqual(getattr(processor, key, None), value)
|
||||
|
||||
def test_processor_from_and_save_pretrained(self):
|
||||
processor_first = self.get_processor()
|
||||
@ -532,6 +538,10 @@ class ProcessorTesterMixin:
|
||||
|
||||
def test_chat_template_save_loading(self):
|
||||
processor = self.get_processor()
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "chat_template" not in {*signature.parameters.keys()}:
|
||||
self.skipTest("Processor doesn't accept chat templates at input")
|
||||
|
||||
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
|
||||
processor.chat_template = "test template"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@ -553,3 +563,298 @@ class ProcessorTesterMixin:
|
||||
# When we save as single files, tokenizers and processors share a chat template, which means
|
||||
# the reloaded tokenizer should get the chat template as well
|
||||
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
|
||||
|
||||
def test_chat_template_single(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
add_special_tokens = True
|
||||
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
|
||||
add_special_tokens = False
|
||||
expected_output = processor.tokenizer(
|
||||
formatted_prompt, return_tensors=None, add_special_tokens=add_special_tokens
|
||||
).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Now test the ability to return dict
|
||||
messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertTrue(self.images_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 1)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 1)
|
||||
|
||||
def test_chat_template_batched(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
batched_messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 2)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, padding=True
|
||||
)
|
||||
add_special_tokens = True
|
||||
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
|
||||
add_special_tokens = False
|
||||
expected_output = processor.tokenizer(
|
||||
formatted_prompt,
|
||||
return_tensors=None,
|
||||
padding=True,
|
||||
add_special_tokens=add_special_tokens,
|
||||
).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Now test the ability to return dict
|
||||
batched_messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
batched_messages[1][0]["content"].append(
|
||||
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(
|
||||
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
self.assertTrue(self.images_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 2)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 2)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 2)
|
||||
|
||||
def test_chat_template_accepts_processing_kwargs(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
padding="max_length",
|
||||
max_length=50,
|
||||
)
|
||||
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
truncation=True,
|
||||
max_length=5,
|
||||
)
|
||||
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
|
||||
|
||||
# Now test the ability to return dict
|
||||
messages[0][0]["content"].append(
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
)
|
||||
out_dict = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
do_rescale=True,
|
||||
rescale_factor=-1,
|
||||
return_tensors="np",
|
||||
)
|
||||
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
|
||||
|
||||
@require_torch
|
||||
def test_chat_template_dict_torch(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
out_dict_tensors = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.images_input_name in out_dict_tensors)
|
||||
for k in out_dict_tensors:
|
||||
self.assertIsInstance(out_dict_tensors[k], torch.Tensor)
|
||||
|
||||
@require_av
|
||||
def test_chat_template_video(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
add_special_tokens = True
|
||||
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
|
||||
add_special_tokens = False
|
||||
expected_output = processor.tokenizer(
|
||||
formatted_prompt,
|
||||
return_tensors=None,
|
||||
add_special_tokens=add_special_tokens,
|
||||
).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Add video URL for return dict and load with `num_frames` arg
|
||||
messages[0][0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
}
|
||||
num_frames = 3
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), num_frames)
|
||||
|
||||
# Load with `video_fps` arg
|
||||
video_fps = 1
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), video_fps * 10)
|
||||
|
||||
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# Load without any arg should load the whole video
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 300)
|
||||
|
||||
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||
# because we assume they come from one video
|
||||
messages[0][0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": [
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
],
|
||||
}
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
|
||||
|
Loading…
Reference in New Issue
Block a user