mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 10:41:07 +06:00
Make chat templates part of ProcessorMixin (#30744)
* Let's try moving chat templates out of IDEFICS and into the generic ProcessorMixin * Chat templates should not be mandatory * Chat templates should not be mandatory * Not all classes will have default chat templates * stash commit * Add chat template docstring * Clean up docstring * Add chat templates to LLaVA/LLaVA-next * Docstring fixup * Quick IDEFICS2 fixup * Remove some old references to the Conversation class * make fixup
This commit is contained in:
parent
3c4a8dca0c
commit
15b3923d65
@ -16,7 +16,7 @@
|
|||||||
Processor class for IDEFICS2.
|
Processor class for IDEFICS2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, is_valid_image, load_image
|
from ...image_utils import ImageInput, is_valid_image, load_image
|
||||||
@ -56,13 +56,15 @@ class Idefics2Processor(ProcessorMixin):
|
|||||||
The length of the image sequence i.e. the number of <image> tokens per image in the input.
|
The length of the image sequence i.e. the number of <image> tokens per image in the input.
|
||||||
This parameter is used to build the string from the input prompt and image tokens and should match the
|
This parameter is used to build the string from the input prompt and image tokens and should match the
|
||||||
config.perceiver_config.resampler_n_latents value for the model used.
|
config.perceiver_config.resampler_n_latents value for the model used.
|
||||||
|
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||||
|
in a chat into a tokenizable string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
attributes = ["image_processor", "tokenizer"]
|
attributes = ["image_processor", "tokenizer"]
|
||||||
image_processor_class = "Idefics2ImageProcessor"
|
image_processor_class = "Idefics2ImageProcessor"
|
||||||
tokenizer_class = "AutoTokenizer"
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **kwargs):
|
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, chat_template: str = None, **kwargs):
|
||||||
if image_processor is None:
|
if image_processor is None:
|
||||||
raise ValueError("You need to specify an `image_processor`.")
|
raise ValueError("You need to specify an `image_processor`.")
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
@ -78,10 +80,7 @@ class Idefics2Processor(ProcessorMixin):
|
|||||||
}
|
}
|
||||||
tokenizer.add_special_tokens(tokens_to_add)
|
tokenizer.add_special_tokens(tokens_to_add)
|
||||||
|
|
||||||
# Stores a Jinja template that formats chat histories into tokenizable strings
|
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||||
self.chat_template = kwargs.pop("chat_template", None)
|
|
||||||
|
|
||||||
super().__init__(image_processor, tokenizer)
|
|
||||||
|
|
||||||
def _extract_images_from_prompts(self, prompts):
|
def _extract_images_from_prompts(self, prompts):
|
||||||
prompt_images = []
|
prompt_images = []
|
||||||
@ -252,49 +251,6 @@ class Idefics2Processor(ProcessorMixin):
|
|||||||
image_processor_input_names = self.image_processor.model_input_names
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||||
|
|
||||||
def apply_chat_template(
|
|
||||||
self,
|
|
||||||
conversation: Union[List[Dict[str, str]]],
|
|
||||||
chat_template: Optional[str] = None,
|
|
||||||
tokenize: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Overrides the tokenizer's `apply_chat_template` method to apply the IDEFICS2 chat template by default
|
|
||||||
if no chat template is provided.
|
|
||||||
|
|
||||||
By default, the output isn't tokenized. This is because the IDEFICS2 chat template is designed to insert
|
|
||||||
the image token <image> into the sequence according to the message, but does not handle expanding the image
|
|
||||||
tokens to the sequence length or adding the surrounding tokens e.g. <fake_image_token>.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation (`Union[List[Dict, str, str]]`):
|
|
||||||
The conversation to format.
|
|
||||||
chat_template (`Optional[str]`, *optional*):
|
|
||||||
The Jinja template to use for formatting the conversation. If not provided, the default chat template
|
|
||||||
is used.
|
|
||||||
tokenize (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to tokenize the output or not.
|
|
||||||
**kwargs:
|
|
||||||
Additional keyword arguments for the tokenizer's `apply_chat_template` method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if chat_template is None:
|
|
||||||
if self.chat_template is not None:
|
|
||||||
chat_template = self.chat_template
|
|
||||||
else:
|
|
||||||
logger.warning_once(
|
|
||||||
"No chat template is set for this processor, falling back to a default class-level template. This is "
|
|
||||||
"very error-prone, because models are often trained with templates different from the class default! "
|
|
||||||
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
|
|
||||||
"point any code depending on them will stop working. We recommend setting a valid chat template before "
|
|
||||||
"then to ensure that this model continues working without issues."
|
|
||||||
)
|
|
||||||
chat_template = self.default_chat_template
|
|
||||||
return self.tokenizer.apply_chat_template(
|
|
||||||
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_chat_template(self):
|
def default_chat_template(self):
|
||||||
"""
|
"""
|
||||||
|
@ -37,14 +37,16 @@ class LlavaProcessor(ProcessorMixin):
|
|||||||
The image processor is a required input.
|
The image processor is a required input.
|
||||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||||
The tokenizer is a required input.
|
The tokenizer is a required input.
|
||||||
|
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||||
|
in a chat into a tokenizable string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
attributes = ["image_processor", "tokenizer"]
|
attributes = ["image_processor", "tokenizer"]
|
||||||
image_processor_class = "AutoImageProcessor"
|
image_processor_class = "AutoImageProcessor"
|
||||||
tokenizer_class = "AutoTokenizer"
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
def __init__(self, image_processor=None, tokenizer=None):
|
def __init__(self, image_processor=None, tokenizer=None, chat_template=None):
|
||||||
super().__init__(image_processor, tokenizer)
|
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -37,14 +37,16 @@ class LlavaNextProcessor(ProcessorMixin):
|
|||||||
The image processor is a required input.
|
The image processor is a required input.
|
||||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||||
The tokenizer is a required input.
|
The tokenizer is a required input.
|
||||||
|
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||||
|
in a chat into a tokenizable string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
attributes = ["image_processor", "tokenizer"]
|
attributes = ["image_processor", "tokenizer"]
|
||||||
image_processor_class = "AutoImageProcessor"
|
image_processor_class = "AutoImageProcessor"
|
||||||
tokenizer_class = "AutoTokenizer"
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
def __init__(self, image_processor=None, tokenizer=None):
|
def __init__(self, image_processor=None, tokenizer=None, chat_template=None):
|
||||||
super().__init__(image_processor, tokenizer)
|
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -22,7 +22,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
@ -60,6 +60,7 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
attributes = ["feature_extractor", "tokenizer"]
|
attributes = ["feature_extractor", "tokenizer"]
|
||||||
|
optional_attributes = ["chat_template"]
|
||||||
# Names need to be attr_class for attr in attributes
|
# Names need to be attr_class for attr in attributes
|
||||||
feature_extractor_class = None
|
feature_extractor_class = None
|
||||||
tokenizer_class = None
|
tokenizer_class = None
|
||||||
@ -67,6 +68,10 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
|
|
||||||
# args have to match the attributes class attribute
|
# args have to match the attributes class attribute
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
# First, extract optional attributes from kwargs if present
|
||||||
|
# Optional attributes can never be positional arguments
|
||||||
|
for optional_attribute in self.optional_attributes:
|
||||||
|
setattr(self, optional_attribute, kwargs.pop(optional_attribute, None))
|
||||||
# Sanitize args and kwargs
|
# Sanitize args and kwargs
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
if key not in self.attributes:
|
if key not in self.attributes:
|
||||||
@ -522,6 +527,51 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
first_attribute = getattr(self, self.attributes[0])
|
first_attribute = getattr(self, self.attributes[0])
|
||||||
return getattr(first_attribute, "model_input_names", None)
|
return getattr(first_attribute, "model_input_names", None)
|
||||||
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
conversation: Union[List[Dict[str, str]]],
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
tokenize: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
|
||||||
|
conversations to turn them into a single tokenizable string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation (`List[Dict, str, str]`):
|
||||||
|
The conversation to format.
|
||||||
|
chat_template (`Optional[str]`, *optional*):
|
||||||
|
The Jinja template to use for formatting the conversation. If not provided, the default chat template
|
||||||
|
is used.
|
||||||
|
tokenize (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to tokenize the output or not.
|
||||||
|
**kwargs:
|
||||||
|
Additional keyword arguments
|
||||||
|
"""
|
||||||
|
|
||||||
|
if chat_template is None:
|
||||||
|
if self.chat_template is not None:
|
||||||
|
chat_template = self.chat_template
|
||||||
|
elif getattr(self, "default_chat_template", None) is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"No chat template is set for this processor, falling back to a default class-level template. This is "
|
||||||
|
"very error-prone, because models are often trained with templates different from the class default! "
|
||||||
|
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
|
||||||
|
"point any code depending on them will stop working. We recommend setting a valid chat template before "
|
||||||
|
"then to ensure that this model continues working without issues."
|
||||||
|
)
|
||||||
|
chat_template = self.default_chat_template
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"No chat template is set for this processor. Please either set the `chat_template` attribute, "
|
||||||
|
"or provide a chat template as an argument. See "
|
||||||
|
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
|
||||||
|
)
|
||||||
|
return self.tokenizer.apply_chat_template(
|
||||||
|
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
||||||
if ProcessorMixin.push_to_hub.__doc__ is not None:
|
if ProcessorMixin.push_to_hub.__doc__ is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user