mirror of https://github.com/huggingface/transformers.git (synced 2025-07-19 20:48:22 +06:00)

Chat template: update for processor (#35953)

* update
* we need batched nested input to always process correctly
* update a bit
* fix copies

This commit is contained in:
  parent 5bd7694781
  commit eebd2c972c
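
For orientation, a minimal usage sketch of what this change enables (the checkpoint name and message content are illustrative only; `video_fps` is the chat-template kwarg introduced in this PR and is mutually exclusive with `num_frames`):

```python
from transformers import AutoProcessor

# Illustrative checkpoint; any processor with a chat template that accepts videos works similarly.
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4"},
            {"type": "text", "text": "What is shown in this video?"},
        ],
    },
]

# `video_fps` samples frames at the given rate; pass either `video_fps` or `num_frames`, not both.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    video_fps=1,
    return_tensors="pt",
)
```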
@@ -562,7 +562,7 @@ def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] =
     return indices


-def read_video_opencv(video_path: str, num_frames: Optional[int] = None):
+def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
     """
     Decode the video with open-cv decoder.

@@ -570,13 +570,25 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None):
         video_path (`str`):
             Path to the video file.
         num_frames (`int`, *optional*):
-            Number of frames to sample uniformly. If not specified, all frames are sampled.
+            Number of frames to sample uniformly. Should be passed only when `fps=None`.
+            If not specified and `fps==None`, all frames are sampled.
+        fps (`int`, *optional*):
+            Number of frames to sample per second. Should be passed only when `num_frames=None`.
+            If not specified and `num_frames==None`, all frames are sampled.

     Returns:
         np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
     """
     video = cv2.VideoCapture(video_path)
     total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+    video_fps = video.get(cv2.CAP_PROP_FPS)
+    if num_frames is None and fps is not None:
+        num_frames = int(total_num_frames / video_fps * fps)
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
+                f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
+            )
     indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)

     index = 0
@@ -595,7 +607,7 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None):
     return np.stack(frames)


-def read_video_decord(video_path: str, num_frames: Optional[int] = None):
+def read_video_decord(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
     """
     Decode the video with Decord decoder.

@@ -603,18 +615,31 @@ def read_video_decord(video_path: str, num_frames: Optional[int] = None):
         video_path (`str`):
             Path to the video file.
         num_frames (`int`, *optional*):
-            Number of frames to sample uniformly. If not specified, all frames are sampled.
+            Number of frames to sample uniformly. Should be passed only when `fps=None`.
+            If not specified and `fps==None`, all frames are sampled.
+        fps (`int`, *optional*):
+            Number of frames to sample per second. Should be passed only when `num_frames=None`.
+            If not specified and `num_frames==None`, all frames are sampled.

     Returns:
         np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
     """
     vr = VideoReader(uri=video_path, ctx=cpu(0))  # decord has problems with gpu
-    indices = get_uniform_frame_indices(total_num_frames=len(vr), num_frames=num_frames)
+    video_fps = vr.get_avg_fps()
+    total_num_frames = len(vr)
+    if num_frames is None and fps is not None:
+        num_frames = int(total_num_frames / video_fps * fps)
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
+                f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
+            )
+    indices = get_uniform_frame_indices(total_num_frames=total_num_frames, num_frames=num_frames)
     frames = vr.get_batch(indices).asnumpy()
     return frames


-def read_video_pyav(video_path: str, num_frames: Optional[int] = None):
+def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
     """
     Decode the video with PyAV decoder.

@@ -622,15 +647,26 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None):
         video_path (`str`):
             Path to the video file.
         num_frames (`int`, *optional*):
-            Number of frames to sample uniformly. If not specified, all frames are sampled.
+            Number of frames to sample uniformly. Should be passed only when `fps=None`.
+            If not specified and `fps==None`, all frames are sampled.
+        fps (`int`, *optional*):
+            Number of frames to sample per second. Should be passed only when `num_frames=None`.
+            If not specified and `num_frames==None`, all frames are sampled.

     Returns:
         np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
     """
     container = av.open(video_path)

-    # sample uniformly "num_frames" frames from the video
     total_num_frames = container.streams.video[0].frames
+    video_fps = container.streams.video[0].average_rate  # should we better use `av_guess_frame_rate`?
+    if num_frames is None and fps is not None:
+        num_frames = int(total_num_frames / video_fps * fps)
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
+                f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
+            )
     indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)

     frames = []
@@ -644,7 +680,7 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None):
     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


-def read_video_torchvision(video_path: str, num_frames: Optional[int] = None):
+def read_video_torchvision(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
     """
     Decode the video with torchvision decoder.

@@ -652,7 +688,11 @@ def read_video_torchvision(video_path: str, num_frames: Optional[int] = None):
         video_path (`str`):
             Path to the video file.
         num_frames (`int`, *optional*):
-            Number of frames to sample uniformly. If not specified, all frames are sampled.
+            Number of frames to sample uniformly. Should be passed only when `fps=None`.
+            If not specified and `fps==None`, all frames are sampled.
+        fps (`int`, *optional*):
+            Number of frames to sample per second. Should be passed only when `num_frames=None`.
+            If not specified and `num_frames==None`, all frames are sampled.

     Returns:
         np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
@@ -664,6 +704,15 @@ def read_video_torchvision(video_path: str, num_frames: Optional[int] = None):
         pts_unit="sec",
         output_format="TCHW",
     )
+    video_fps = info["video_fps"]
+    total_num_frames = video.size(0) - 1
+    if num_frames is None and fps is not None:
+        num_frames = int(total_num_frames / video_fps * fps)
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
+                f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
+            )

     if num_frames is not None:
         idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
@@ -680,7 +729,12 @@ VIDEO_DECODERS = {
 }


-def load_video(video: Union[str, "VideoInput"], num_frames: Optional[int] = None, backend: str = "opencv") -> np.array:
+def load_video(
+    video: Union[str, "VideoInput"],
+    num_frames: Optional[int] = None,
+    fps: Optional[int] = None,
+    backend: str = "opencv",
+) -> np.array:
     """
     Loads `video` to a numpy array.

@@ -689,12 +743,19 @@ def load_video(video: Union[str, "VideoInput"], num_frames: Optional[int] = None
             The video to convert to the numpy array format. Can be a link to video or local path.
         num_frames (`int`, *optional*):
             Number of frames to sample uniformly. If not passed, the whole video is loaded.
+        fps (`int`, *optional*):
+            Number of frames to sample per second. Should be passed only when `num_frames=None`.
+            If not specified and `num_frames==None`, all frames are sampled.
         backend (`str`, *optional*, defaults to `"opencv"`):
             The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".

     Returns:
         `np.array`: A numpy array of shape (num_frames, channels, height, width).
     """
+    if fps is not None and num_frames is not None:
+        raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
+
     if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
         if not is_yt_dlp_available():
             raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
@@ -735,7 +796,7 @@ def load_video(video: Union[str, "VideoInput"], num_frames: Optional[int] = None
     )

     video_decoder = VIDEO_DECODERS[backend]
-    video = video_decoder(file_obj)
+    video = video_decoder(file_obj, num_frames=num_frames, fps=fps)
     return video


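All four decoder backends above implement the same `fps` handling: the requested sampling rate is converted into a frame count from the container metadata, and the existing uniform-index helper is then reused. A self-contained sketch of that conversion (function name and values are illustrative; the logic mirrors the diff):

```python
# Minimal sketch of the fps -> num_frames conversion shared by the decoders in this diff.
# `total_num_frames` and `video_fps` come from the container metadata.
def frames_for_requested_fps(total_num_frames: int, video_fps: float, fps: int) -> int:
    num_frames = int(total_num_frames / video_fps * fps)
    if num_frames > total_num_frames:
        raise ValueError(
            f"Requested fps={fps} needs {num_frames} frames but the video only has "
            f"{total_num_frames} (video_fps={video_fps})."
        )
    return num_frames


# A 10-second clip at 30 fps (300 frames) sampled at 2 fps keeps 20 frames.
print(frames_for_requested_fps(total_num_frames=300, video_fps=30.0, fps=2))  # 20
```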
@@ -110,6 +110,8 @@ class AriaImageProcessor(BaseImageProcessor):
             The resampling filter to use if resizing the image.
     """

+    model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
+
     def __init__(
         self,
         image_mean: List[float] = None,
@@ -499,6 +499,8 @@ class AriaImageProcessor(BaseImageProcessor):
             The resampling filter to use if resizing the image.
     """

+    model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
+
     def __init__(
         self,
         image_mean: List[float] = None,
@@ -997,6 +999,10 @@ class AriaProcessor(ProcessorMixin):
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
         image_processor_input_names = self.image_processor.model_input_names
+
+        # Remove `num_crops`, it is popped and used only when processing. Make a copy of the list when removing,
+        # otherwise `self.image_processor.model_input_names` is also modified
+        image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


@@ -158,6 +158,10 @@ class AriaProcessor(ProcessorMixin):
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
         image_processor_input_names = self.image_processor.model_input_names
+
+        # Remove `num_crops`, it is popped and used only when processing. Make a copy of the list when removing,
+        # otherwise `self.image_processor.model_input_names` is also modified
+        image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


@@ -132,7 +132,7 @@ class Emu3ImageProcessor(BaseImageProcessor):
             The spatial downsample factor the image will be downsampled in feature extracting phase
     """

-    model_input_names = ["pixel_values"]
+    model_input_names = ["pixel_values", "image_sizes"]

     def __init__(
         self,
@@ -63,6 +63,7 @@ class Emu3Processor(ProcessorMixin):
     """

     attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
     tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
     image_processor_class = "Emu3ImageProcessor"

@@ -179,7 +180,7 @@ class Emu3Processor(ProcessorMixin):
         data = self.tokenizer(text, **output_kwargs["text_kwargs"])
         data.update(**image_features)

-        return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
+        return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].pop("return_tensors", None))

     def calculate_generate_size(self, ratio, image_area, spatial_factor):
         width, height = map(int, ratio.split(":"))
@@ -184,7 +184,7 @@ class Idefics2ImageProcessor(BaseImageProcessor):
             strategy was first introduced in https://arxiv.org/abs/2311.06607.
     """

-    model_input_names = ["pixel_values"]
+    model_input_names = ["pixel_values", "pixel_attention_mask"]

     def __init__(
         self,
@@ -289,7 +289,7 @@ class Idefics3ImageProcessor(BaseImageProcessor):
             sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
     """

-    model_input_names = ["pixel_values"]
+    model_input_names = ["pixel_values", "pixel_attention_mask"]

     def __init__(
         self,
@@ -129,6 +129,7 @@ class Idefics3Processor(ProcessorMixin):
     """

     attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["image_seq_len", "chat_template"]
     image_processor_class = "Idefics3ImageProcessor"
     tokenizer_class = "AutoTokenizer"

@@ -354,7 +355,7 @@ class Idefics3Processor(ProcessorMixin):
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
         image_processor_input_names = self.image_processor.model_input_names
-        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+        return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))


 __all__ = ["Idefics3Processor"]
@@ -163,7 +163,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
             Whether to convert the image to RGB.
     """

-    model_input_names = ["pixel_values"]
+    model_input_names = ["pixel_values", "image_sizes"]

     def __init__(
         self,
@@ -202,14 +202,17 @@ class MllamaProcessor(ProcessorMixin):
             The image processor is a required input.
         tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
             The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.

     """

     attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
     image_processor_class = "MllamaImageProcessor"
     tokenizer_class = "PreTrainedTokenizerFast"

-    def __init__(self, image_processor, tokenizer):
+    def __init__(self, image_processor, tokenizer, chat_template=None):
         if not hasattr(tokenizer, "image_token"):
             self.image_token = "<|image|>"
             self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
@@ -220,8 +223,7 @@ class MllamaProcessor(ProcessorMixin):
         self.python_token = "<|python_tag|>"
         self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
         self.bos_token = tokenizer.bos_token
-        self.chat_template = tokenizer.chat_template
-        super().__init__(image_processor, tokenizer)
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)

     def __call__(
         self,
@@ -364,6 +366,10 @@ class MllamaProcessor(ProcessorMixin):
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
         image_processor_input_names = self.image_processor.model_input_names
+
+        # Remove `num_tiles`, it is popped and used only when processing. Make a copy of the list when removing,
+        # otherwise `self.image_processor.model_input_names` is also modified
+        image_processor_input_names = [name for name in image_processor_input_names if name != "num_tiles"]
         return list(tokenizer_input_names + image_processor_input_names + ["cross_attention_mask"])


@@ -379,6 +379,9 @@ class ChatTemplateKwargs(TypedDict, total=False):
             The backend to use when loading the video which will be used only when there are videos in the conversation.
             Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend
             that supports all types of sources to load from.
+        video_fps (`int`, *optional*):
+            Number of frames to sample per second. Should be passed only when `num_frames=None`.
+            If not specified and `num_frames==None`, all frames are sampled.
     """

     tokenize: Optional[bool] = False
@@ -390,6 +393,7 @@ class ChatTemplateKwargs(TypedDict, total=False):
     return_assistant_tokens_mask: Optional[bool] = False
     num_frames: Optional[int] = None
     video_load_backend: Optional[str] = "pyav"
+    video_fps: Optional[int] = None


 class AllKwargsForChatTemplate(
@@ -762,7 +766,11 @@ class ProcessorMixin(PushToHubMixin):
         # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
         # However, for models added in the future, we won't get the expected error if this file is missing.
         if resolved_processor_file is None:
-            return {}, kwargs
+            # In any case we need to pass `chat_template` if it is available
+            processor_dict = {}
+            if "chat_template" in kwargs:
+                processor_dict = {"chat_template": kwargs.pop("chat_template")}
+            return processor_dict, kwargs

         try:
             # Load processor dict
@@ -786,6 +794,9 @@ class ProcessorMixin(PushToHubMixin):
                 "in the processor's config. Make sure to move your template to its own file."
             )

+        if "chat_template" in kwargs:
+            processor_dict["chat_template"] = kwargs.pop("chat_template")
+
         if not is_local:
             if "auto_map" in processor_dict:
                 processor_dict["auto_map"] = add_model_info_to_auto_map(
@@ -817,7 +828,6 @@ class ProcessorMixin(PushToHubMixin):
         """
         processor_dict = processor_dict.copy()
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-        chat_template = kwargs.pop("chat_template", None)

         # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
         # If we don't pop, some specific kwargs will raise a warning
@@ -829,8 +839,6 @@ class ProcessorMixin(PushToHubMixin):

         unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
         processor = cls(*args, **processor_dict)
-        if chat_template is not None:
-            setattr(processor, "chat_template", chat_template)

         # Update processor with kwargs if needed
         for key in set(kwargs.keys()):
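The loading changes above keep `chat_template` flowing through `get_processor_dict` instead of being set on the instance after construction. A hedged sketch of the save/load round-trip this supports (the checkpoint and template below are placeholders, not part of the diff):

```python
# The template is written to its own chat_template.json on save and restored on load.
import tempfile
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

with tempfile.TemporaryDirectory() as tmp:
    processor.save_pretrained(tmp)                 # writes chat_template.json next to the processor config
    reloaded = AutoProcessor.from_pretrained(tmp)  # get_processor_dict passes the template back through
    assert reloaded.chat_template == processor.chat_template
```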
@@ -1199,12 +1207,6 @@ class ProcessorMixin(PushToHubMixin):
                 "https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
             )

-        text_kwargs = {}
-        for key in TextKwargs.__annotations__.keys():
-            value = kwargs.pop(key, None)
-            if value is not None:
-                text_kwargs[key] = value
-
         chat_template_kwargs = {}
         for key in ChatTemplateKwargs.__annotations__.keys():
             value = kwargs.pop(key, getattr(ChatTemplateKwargs, key))
@@ -1214,6 +1216,7 @@ class ProcessorMixin(PushToHubMixin):
         tokenize = chat_template_kwargs.pop("tokenize")
         return_dict = chat_template_kwargs.pop("return_dict")
         num_frames = chat_template_kwargs.pop("num_frames")
+        video_fps = chat_template_kwargs.pop("video_fps")
         video_load_backend = chat_template_kwargs.pop("video_load_backend")

         prompt = self.tokenizer.apply_chat_template(
@@ -1221,31 +1224,68 @@ class ProcessorMixin(PushToHubMixin):
             chat_template=chat_template,
             tokenize=False,
             return_dict=False,
-            **text_kwargs,
             **chat_template_kwargs,
         )

-        # we will have to return all processed inputs in a dict
+        if isinstance(conversation, (list, tuple)) and (
+            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
+        ):
+            conversations = conversation
+            is_batched = True
+        else:
+            conversations = [conversation]
+            is_batched = False
+
         if tokenize:
+            batch_images, batch_videos = [], []
+            for conversation in conversations:
                 images, videos = [], []
                 for message in conversation:
                     visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
-                    for vision_info in visuals:
-                        if vision_info["type"] == "image":
-                            for key in ["image", "url", "path", "base64"]:
-                                if key in vision_info:
-                                    images.append(load_image(vision_info[key]))
-                        elif vision_info["type"] == "video":
-                            for key in ["video", "url", "path"]:
-                                if key in vision_info:
-                                    videos.append(
-                                        load_video(vision_info[key], num_frames=num_frames, backend=video_load_backend)
-                                    )
+                    image_fnames = [
+                        vision_info[key]
+                        for vision_info in visuals
+                        for key in ["image", "url", "path", "base64"]
+                        if key in vision_info and vision_info["type"] == "image"
+                    ]
+                    video_fnames = [
+                        vision_info[key]
+                        for vision_info in visuals
+                        for key in ["video", "url", "path"]
+                        if key in vision_info and vision_info["type"] == "video"
+                    ]
+                    for fname in image_fnames:
+                        images.append(load_image(fname))
+                    for fname in video_fnames:
+                        if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
+                            video = [np.array(load_image(image_fname)).T for image_fname in fname]
+                            # create a 4D video because `load_video` always returns a 4D array
+                            video = np.stack(video)
+                        else:
+                            video = load_video(fname, num_frames=num_frames, fps=video_fps, backend=video_load_backend)
+                        videos.append(video)
+
+                # Currently all processors can accept nested lists of batches, but not a flat list of visuals
+                # So we'll make a batched list of images and let the processor handle it
+                if images:
+                    batch_images.append(images)
+                if videos:
+                    batch_videos.append(videos)
+
+            # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
+            # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
+            # and pass it to the processor. Users thus never worried about special tokens, relying on the processor handling
+            # everything internally. The below line is to keep BC for that and be able to work with models that have
+            # special tokens in the template (consistent with tokenizers). We don't want to raise a warning, it would flood
+            # the command line without an actionable solution for users
+            single_prompt = prompt[0] if is_batched else prompt
+            if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
+                kwargs["add_special_tokens"] = False
+
             out = self(
                 text=prompt,
-                images=images if images else None,
-                videos=videos if videos else None,
+                images=batch_images if batch_images else None,
+                videos=batch_videos if batch_videos else None,
                 **kwargs,
             )
             if return_dict:
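The rewritten `apply_chat_template` above detects a list of conversations, collects visuals per conversation into nested (batched) lists, and calls the processor once; this is what the "batched nested input" bullet in the commit message refers to. A usage sketch under that assumption (the checkpoint and image URLs are illustrative, borrowed from the tests further down):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

batched_messages = [
    [{"role": "user", "content": [
        {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
        {"type": "text", "text": "What is shown in this image?"},
    ]}],
    [{"role": "user", "content": [
        {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
        {"type": "text", "text": "What do you see?"},
    ]}],
]

out = processor.apply_chat_template(
    batched_messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    padding=True,
)
# out["input_ids"] has batch size 2; image features for both conversations are returned as well.
```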
@@ -237,6 +237,55 @@ And who is that?<|im_end|>
 """
         self.assertEqual(rendered, expected_rendered)

+    # Override as AriaImageProcessor doesn't accept `do_rescale`
+    def test_chat_template_accepts_processing_kwargs(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                    ],
+                },
+            ]
+        ]
+
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            padding="max_length",
+            max_length=50,
+        )
+        self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            truncation=True,
+            max_length=5,
+        )
+        self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
+
+        # Now test the ability to return dict
+        messages[0][0]["content"].append(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        out_dict = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            max_image_size=980,
+            return_tensors="np",
+        )
+        self.assertListEqual(list(out_dict[self.images_input_name].shape), [1, 3, 980, 980])
+
     # Override as AriaProcessor needs image tokens in prompts
     def prepare_text_inputs(self, batch_size: Optional[int] = None):
         if batch_size is None:
@@ -52,6 +52,11 @@ class Emu3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         )
         processor.save_pretrained(self.tmpdirname)

+    def prepare_processor_dict(self):
+        return {
+            "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
+        } # fmt: skip
+
     def test_processor_for_generation(self):
         processor_components = self.prepare_components()
         processor = self.processor_class(**processor_components)
@@ -17,7 +17,7 @@ import tempfile
 import unittest

 from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
-from transformers.testing_utils import require_torch, require_vision
+from transformers.testing_utils import require_vision
 from transformers.utils import is_torch_available, is_vision_available

 from ...test_processing_common import ProcessorTesterMixin
@@ -27,7 +27,7 @@ if is_vision_available():
     from transformers import CLIPImageProcessor

 if is_torch_available:
-    import torch
+    pass


 @require_vision
@@ -53,7 +53,11 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         shutil.rmtree(self.tmpdirname)

     def prepare_processor_dict(self):
-        return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"}
+        return {
+            "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
+            "patch_size": 3,
+            "vision_feature_select_strategy": "default"
+        } # fmt: skip

     @unittest.skip(
         "Skip because the model has no processor kwargs except for chat template and"
@@ -123,29 +127,6 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         )
         self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])

-    @require_torch
-    def test_chat_template_dict_torch(self):
-        processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
-                    {"type": "text", "text": "What is shown in this image?"},
-                ],
-            },
-        ]
-
-        out_dict_tensors = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
-        self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values"])
-        self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
-
     def test_chat_template_with_continue_final_message(self):
         processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
         expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
@@ -50,7 +50,11 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

     def prepare_processor_dict(self):
-        return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"}
+        return {
+            "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
+            "patch_size": 3,
+            "vision_feature_select_strategy": "default"
+        } # fmt: skip

     @unittest.skip(
         "Skip because the model has no processor kwargs except for chat template and"
@@ -16,7 +16,7 @@ import shutil
 import tempfile
 import unittest

-from transformers.testing_utils import require_av, require_torch, require_vision
+from transformers.testing_utils import require_av, require_vision
 from transformers.utils import is_torch_available, is_vision_available

 from ...test_processing_common import ProcessorTesterMixin
@@ -32,7 +32,7 @@ if is_vision_available():
     )

 if is_torch_available:
-    import torch
+    pass


 @require_vision
@@ -61,7 +61,11 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor

     def prepare_processor_dict(self):
-        return {"chat_template": "dummy_template", "num_image_tokens": 6, "vision_feature_select_strategy": "default"}
+        return {
+            "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ '\n' + content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+            "num_image_tokens": 6,
+            "vision_feature_select_strategy": "default"
+        } # fmt: skip

     def test_processor_to_json_string(self):
         processor = self.get_processor()
@@ -133,30 +137,3 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
             messages, add_generation_prompt=True, tokenize=True, return_dict=True
         )
         self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
-
-    @require_torch
-    @require_av
-    def test_chat_template_dict_torch(self):
-        processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "video",
-                        "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
-                    },
-                    {"type": "text", "text": "What is shown in this video?"},
-                ],
-            },
-        ]
-
-        out_dict_tensors = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
-        self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
-        self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import shutil
 import tempfile
 import unittest
@@ -52,6 +53,20 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)

+    def prepare_processor_dict(self):
+        return {"chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} # fmt: skip
+
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to json in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved as separate file and loaded back from that file
+        # so we check if the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
     def test_apply_chat_template(self):
         # Message contains content which a mix of lists with images and image urls and string
         messages = [
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import shutil
 import tempfile
 import unittest
@@ -19,7 +20,7 @@ import unittest
 import pytest

 from transformers import AutoProcessor, Qwen2Tokenizer
-from transformers.testing_utils import require_torch, require_vision
+from transformers.testing_utils import require_av, require_torch, require_vision
 from transformers.utils import is_vision_available

 from ...test_processing_common import ProcessorTesterMixin
@@ -45,6 +46,9 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

+    def prepare_processor_dict(self):
+        return {"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"} # fmt: skip
+
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)

|
|||||||
inputs = processor(text=input_str, images=image_input, videos=video_inputs)
|
inputs = processor(text=input_str, images=image_input, videos=video_inputs)
|
||||||
|
|
||||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|
||||||
|
def test_chat_template_single(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||||
|
self.assertEqual(len(formatted_prompt), 1)
|
||||||
|
|
||||||
|
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||||
|
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||||
|
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||||
|
|
||||||
|
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||||
|
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||||
|
|
||||||
|
# Now test the ability to return dict
|
||||||
|
messages[0][0]["content"].append(
|
||||||
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||||
|
)
|
||||||
|
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||||
|
self.assertTrue(self.images_input_name in out_dict)
|
||||||
|
|
||||||
|
# should always have input_ids and attention_mask
|
||||||
|
self.assertEqual(len(out_dict["input_ids"]), 1)
|
||||||
|
self.assertEqual(len(out_dict["attention_mask"]), 1)
|
||||||
|
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
|
||||||
|
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
batched_messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What do you see?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
|
||||||
|
self.assertEqual(len(formatted_prompt), 2)
|
||||||
|
|
||||||
|
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||||
|
batched_messages, add_generation_prompt=True, tokenize=True, padding=True
|
||||||
|
)
|
||||||
|
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None, padding=True).input_ids
|
||||||
|
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||||
|
|
||||||
|
out_dict = processor.apply_chat_template(
|
||||||
|
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||||
|
)
|
||||||
|
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||||
|
|
||||||
|
# Now test the ability to return dict
|
||||||
|
batched_messages[0][0]["content"].append(
|
||||||
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||||
|
)
|
||||||
|
batched_messages[1][0]["content"].append(
|
||||||
|
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
|
||||||
|
)
|
||||||
|
out_dict = processor.apply_chat_template(
|
||||||
|
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
|
||||||
|
)
|
||||||
|
self.assertTrue(self.images_input_name in out_dict)
|
||||||
|
|
||||||
|
# should always have input_ids and attention_mask
|
||||||
|
self.assertEqual(len(out_dict["input_ids"]), 2)
|
||||||
|
self.assertEqual(len(out_dict["attention_mask"]), 2)
|
||||||
|
self.assertEqual(len(out_dict[self.images_input_name]), 90480)
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
def test_chat_template_video(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
signature = inspect.signature(processor.__call__)
|
||||||
|
if "videos" not in {*signature.parameters.keys()} or (
|
||||||
|
signature.parameters.get("videos") is not None
|
||||||
|
and signature.parameters["videos"].annotation == inspect._empty
|
||||||
|
):
|
||||||
|
self.skipTest("Processor doesn't accept videos at input")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video"},
|
||||||
|
{"type": "text", "text": "What is shown in this video?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||||
|
self.assertEqual(len(formatted_prompt), 1)
|
||||||
|
|
||||||
|
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||||
|
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||||
|
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||||
|
|
||||||
|
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||||
|
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||||
|
|
||||||
|
# Add video URL for return dict and load with `num_frames` arg
|
||||||
|
messages[0][0]["content"][0] = {
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||||
|
}
|
||||||
|
num_frames = 3
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
num_frames=num_frames,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 115200)
|
||||||
|
|
||||||
|
# Load with `video_fps` arg
|
||||||
|
video_fps = 1
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
video_fps=video_fps,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 288000)
|
||||||
|
|
||||||
|
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
video_fps=video_fps,
|
||||||
|
num_frames=num_frames,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load without any arg should load the whole video
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8640000)
|
||||||
|
|
||||||
|
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||||
|
# because we assume they come from one video
|
||||||
|
messages[0][0]["content"][0] = {
|
||||||
|
"type": "video",
|
||||||
|
"url": [
|
||||||
|
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||||
|
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 71280)
|
||||||
|
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import shutil
 import tempfile
 import unittest
@@ -19,7 +20,7 @@ import unittest
 import pytest

 from transformers import AutoProcessor, Qwen2Tokenizer
-from transformers.testing_utils import require_torch, require_vision
+from transformers.testing_utils import require_av, require_torch, require_vision
 from transformers.utils import is_vision_available

 from ...test_processing_common import ProcessorTesterMixin
@@ -45,6 +46,9 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def get_image_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

+    def prepare_processor_dict(self):
+        return {"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"} # fmt: skip
+
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)

@@ -108,3 +112,198 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         inputs = processor(text=input_str, images=image_input, videos=video_inputs)

         self.assertListEqual(list(inputs.keys()), processor.model_input_names)
+
+    def test_chat_template_single(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                    ],
+                },
+            ]
+        ]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        self.assertEqual(len(formatted_prompt), 1)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
+        self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Now test the ability to return dict
+        messages[0][0]["content"].append(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertTrue(self.images_input_name in out_dict)
+
+        # should always have input_ids and attention_mask
+        self.assertEqual(len(out_dict["input_ids"]), 1)
+        self.assertEqual(len(out_dict["attention_mask"]), 1)
+        self.assertEqual(len(out_dict[self.images_input_name]), 71280)
+
+    def test_chat_template_batched(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        batched_messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                    ],
+                },
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What do you see?"},
+                    ],
+                },
+            ],
+        ]
+
+        formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
+        self.assertEqual(len(formatted_prompt), 2)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            batched_messages, add_generation_prompt=True, tokenize=True, padding=True
+        )
+        expected_output = processor.tokenizer(formatted_prompt, return_tensors=None, padding=True).input_ids
+        self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+        out_dict = processor.apply_chat_template(
+            batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
+        )
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Now test the ability to return dict
+        batched_messages[0][0]["content"].append(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        batched_messages[1][0]["content"].append(
+            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
+        )
+        out_dict = processor.apply_chat_template(
+            batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
+        )
+        self.assertTrue(self.images_input_name in out_dict)
+
+        # should always have input_ids and attention_mask
+        self.assertEqual(len(out_dict["input_ids"]), 2)
+        self.assertEqual(len(out_dict["attention_mask"]), 2)
+        self.assertEqual(len(out_dict[self.images_input_name]), 90480)
+
+    @require_av
+    def test_chat_template_video(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        signature = inspect.signature(processor.__call__)
+        if "videos" not in {*signature.parameters.keys()} or (
+            signature.parameters.get("videos") is not None
+            and signature.parameters["videos"].annotation == inspect._empty
+        ):
+            self.skipTest("Processor doesn't accept videos at input")
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video"},
+                        {"type": "text", "text": "What is shown in this video?"},
+                    ],
+                },
+            ]
+        ]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        self.assertEqual(len(formatted_prompt), 1)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
+        self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Add video URL for return dict and load with `num_frames` arg
+        messages[0][0]["content"][0] = {
+            "type": "video",
+            "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
+        }
+        num_frames = 3
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            num_frames=num_frames,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 115200)
+
+        # Load with `video_fps` arg
+        video_fps = 1
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            video_fps=video_fps,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 288000)
+
+        # Load with `video_fps` and `num_frames` args, should raise an error
+        with self.assertRaises(ValueError):
+            out_dict_with_video = processor.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                video_fps=video_fps,
+                num_frames=num_frames,
+            )
+
+        # Load without any arg should load the whole video
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8640000)
+
+        # Load video as a list of frames (i.e. images). NOTE: each frame should have same size
+        # because we assume they come from one video
+        messages[0][0]["content"][0] = {
+            "type": "video",
+            "url": [
+                "https://www.ilankelman.org/stopsigns/australia.jpg",
+                "https://www.ilankelman.org/stopsigns/australia.jpg",
+            ],
+        }
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 71280)
@@ -27,10 +27,11 @@ from transformers.models.auto.processing_auto import processor_class_from_name
 from transformers.processing_utils import Unpack
 from transformers.testing_utils import (
     check_json_file_has_correct_format,
+    require_av,
     require_torch,
     require_vision,
 )
-from transformers.utils import is_vision_available
+from transformers.utils import is_torch_available, is_vision_available


 global_rng = random.Random()
@@ -38,6 +39,9 @@ global_rng = random.Random()
 if is_vision_available():
     from PIL import Image

+if is_torch_available():
+    import torch
+

 def prepare_image_inputs():
     """This function prepares a list of PIL images"""
@@ -131,6 +135,8 @@ class ProcessorTesterMixin:
         processor = self.get_processor()
         obj = json.loads(processor.to_json_string())
         for key, value in self.prepare_processor_dict().items():
+            # Chat template is saved as a separate file
+            if key not in "chat_template":
                 self.assertEqual(obj[key], value)
             self.assertEqual(getattr(processor, key, None), value)

@@ -532,6 +538,10 @@ class ProcessorTesterMixin:

     def test_chat_template_save_loading(self):
         processor = self.get_processor()
+        signature = inspect.signature(processor.__call__)
+        if "chat_template" not in {*signature.parameters.keys()}:
+            self.skipTest("Processor doesn't accept chat templates at input")
+
         existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
         processor.chat_template = "test template"
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,3 +563,298 @@ class ProcessorTesterMixin:
         # When we save as single files, tokenizers and processors share a chat template, which means
         # the reloaded tokenizer should get the chat template as well
         self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
+
+    def test_chat_template_single(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                    ],
+                },
+            ]
+        ]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        self.assertEqual(len(formatted_prompt), 1)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        add_special_tokens = True
+        if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
+            add_special_tokens = False
+        expected_output = processor.tokenizer(
+            formatted_prompt, return_tensors=None, add_special_tokens=add_special_tokens
+        ).input_ids
+        self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Now test the ability to return dict
+        messages[0][0]["content"].append(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertTrue(self.images_input_name in out_dict)
+
+        # should always have input_ids and attention_mask
+        self.assertEqual(len(out_dict["input_ids"]), 1)
+        self.assertEqual(len(out_dict["attention_mask"]), 1)
+        self.assertEqual(len(out_dict[self.images_input_name]), 1)
+
+    def test_chat_template_batched(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        batched_messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                    ],
+                },
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What do you see?"},
+                    ],
+                },
+            ],
+        ]
+
+        formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
+        self.assertEqual(len(formatted_prompt), 2)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            batched_messages, add_generation_prompt=True, tokenize=True, padding=True
+        )
+        add_special_tokens = True
+        if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
+            add_special_tokens = False
+        expected_output = processor.tokenizer(
+            formatted_prompt,
+            return_tensors=None,
+            padding=True,
+            add_special_tokens=add_special_tokens,
+        ).input_ids
+        self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+        out_dict = processor.apply_chat_template(
+            batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
+        )
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Now test the ability to return dict
+        batched_messages[0][0]["content"].append(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        batched_messages[1][0]["content"].append(
+            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
+        )
+        out_dict = processor.apply_chat_template(
+            batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
+        )
+        self.assertTrue(self.images_input_name in out_dict)
+
+        # should always have input_ids and attention_mask
+        self.assertEqual(len(out_dict["input_ids"]), 2)
+        self.assertEqual(len(out_dict["attention_mask"]), 2)
+        self.assertEqual(len(out_dict[self.images_input_name]), 2)
+
+    def test_chat_template_accepts_processing_kwargs(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                    ],
+                },
+            ]
+        ]
+
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            padding="max_length",
+            max_length=50,
+        )
+        self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            truncation=True,
+            max_length=5,
+        )
+        self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
+
+        # Now test the ability to return dict
+        messages[0][0]["content"].append(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        out_dict = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            do_rescale=True,
+            rescale_factor=-1,
+            return_tensors="np",
+        )
+        self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
+
+    @require_torch
+    def test_chat_template_dict_torch(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+
+        out_dict_tensors = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.assertTrue(self.images_input_name in out_dict_tensors)
+        for k in out_dict_tensors:
+            self.assertIsInstance(out_dict_tensors[k], torch.Tensor)
+
+    @require_av
+    def test_chat_template_video(self):
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        signature = inspect.signature(processor.__call__)
+        if "videos" not in {*signature.parameters.keys()} or (
+            signature.parameters.get("videos") is not None
+            and signature.parameters["videos"].annotation == inspect._empty
+        ):
+            self.skipTest("Processor doesn't accept videos at input")
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video"},
+                        {"type": "text", "text": "What is shown in this video?"},
+                    ],
+                },
+            ]
+        ]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        self.assertEqual(len(formatted_prompt), 1)
+
+        formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        add_special_tokens = True
+        if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
+            add_special_tokens = False
+        expected_output = processor.tokenizer(
+            formatted_prompt,
+            return_tensors=None,
+            add_special_tokens=add_special_tokens,
+        ).input_ids
+        self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Add video URL for return dict and load with `num_frames` arg
+        messages[0][0]["content"][0] = {
+            "type": "video",
+            "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
+        }
+        num_frames = 3
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            num_frames=num_frames,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), num_frames)
+
+        # Load with `video_fps` arg
+        video_fps = 1
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            video_fps=video_fps,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), video_fps * 10)
+
+        # Load with `video_fps` and `num_frames` args, should raise an error
+        with self.assertRaises(ValueError):
+            out_dict_with_video = processor.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                video_fps=video_fps,
+                num_frames=num_frames,
+            )
+
+        # Load without any arg should load the whole video
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 300)
+
+        # Load video as a list of frames (i.e. images). NOTE: each frame should have same size
+        # because we assume they come from one video
+        messages[0][0]["content"][0] = {
+            "type": "video",
+            "url": [
+                "https://www.ilankelman.org/stopsigns/australia.jpg",
+                "https://www.ilankelman.org/stopsigns/australia.jpg",
+            ],
+        }
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
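Taken together, the tests added above correspond to a user-facing flow roughly like the following sketch. The checkpoint name is a placeholder and not part of this commit; the message schema, video URL, and keyword arguments mirror the tests.

from transformers import AutoProcessor

# Placeholder checkpoint: any processor that ships a chat template and accepts videos is assumed to behave the same way.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4"},
            {"type": "text", "text": "What is shown in this video?"},
        ],
    },
]

# Sample the video at one frame per second while decoding; `video_fps` and `num_frames`
# are mutually exclusive, so passing both raises a ValueError (as asserted in the tests above).
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    video_fps=1,
)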