Make chat templates part of ProcessorMixin (#30744)

* Let's try moving chat templates out of IDEFICS and into the generic ProcessorMixin

* Chat templates should not be mandatory

* Chat templates should not be mandatory

* Not all classes will have default chat templates

* stash commit

* Add chat template docstring

* Clean up docstring

* Add chat templates to LLaVA/LLaVA-next

* Docstring fixup

* Quick IDEFICS2 fixup

* Remove some old references to the Conversation class

* make fixup
This commit is contained in:
Matt 2024-06-13 14:35:30 +01:00 committed by GitHub
parent 3c4a8dca0c
commit 15b3923d65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 54 deletions

View File

@ -16,7 +16,7 @@
Processor class for IDEFICS2.
"""
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
@ -56,13 +56,15 @@ class Idefics2Processor(ProcessorMixin):
The length of the image sequence i.e. the number of <image> tokens per image in the input.
This parameter is used to build the string from the input prompt and image tokens and should match the
config.perceiver_config.resampler_n_latents value for the model used.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "Idefics2ImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **kwargs):
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, chat_template: str = None, **kwargs):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
@ -78,10 +80,7 @@ class Idefics2Processor(ProcessorMixin):
}
tokenizer.add_special_tokens(tokens_to_add)
# Stores a Jinja template that formats chat histories into tokenizable strings
self.chat_template = kwargs.pop("chat_template", None)
super().__init__(image_processor, tokenizer)
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def _extract_images_from_prompts(self, prompts):
prompt_images = []
@ -252,49 +251,6 @@ class Idefics2Processor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]]],
chat_template: Optional[str] = None,
tokenize: bool = False,
**kwargs,
) -> str:
"""
Overrides the tokenizer's `apply_chat_template` method to apply the IDEFICS2 chat template by default
if no chat template is provided.
By default, the output isn't tokenized. This is because the IDEFICS2 chat template is designed to insert
the image token <image> into the sequence according to the message, but does not handle expanding the image
tokens to the sequence length or adding the surrounding tokens e.g. <fake_image_token>.
Args:
conversation (`Union[List[Dict, str, str]]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
**kwargs:
Additional keyword arguments for the tokenizer's `apply_chat_template` method.
"""
if chat_template is None:
if self.chat_template is not None:
chat_template = self.chat_template
else:
logger.warning_once(
"No chat template is set for this processor, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)
chat_template = self.default_chat_template
return self.tokenizer.apply_chat_template(
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
)
@property
def default_chat_template(self):
"""

View File

@ -37,14 +37,16 @@ class LlavaProcessor(ProcessorMixin):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)
def __init__(self, image_processor=None, tokenizer=None, chat_template=None):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,

View File

@ -37,14 +37,16 @@ class LlavaNextProcessor(ProcessorMixin):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)
def __init__(self, image_processor=None, tokenizer=None, chat_template=None):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,

View File

@ -22,7 +22,7 @@ import json
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from .dynamic_module_utils import custom_object_save
from .tokenization_utils_base import PreTrainedTokenizerBase
@ -60,6 +60,7 @@ class ProcessorMixin(PushToHubMixin):
"""
attributes = ["feature_extractor", "tokenizer"]
optional_attributes = ["chat_template"]
# Names need to be attr_class for attr in attributes
feature_extractor_class = None
tokenizer_class = None
@ -67,6 +68,10 @@ class ProcessorMixin(PushToHubMixin):
# args have to match the attributes class attribute
def __init__(self, *args, **kwargs):
# First, extract optional attributes from kwargs if present
# Optional attributes can never be positional arguments
for optional_attribute in self.optional_attributes:
setattr(self, optional_attribute, kwargs.pop(optional_attribute, None))
# Sanitize args and kwargs
for key in kwargs:
if key not in self.attributes:
@ -522,6 +527,51 @@ class ProcessorMixin(PushToHubMixin):
first_attribute = getattr(self, self.attributes[0])
return getattr(first_attribute, "model_input_names", None)
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]]],
chat_template: Optional[str] = None,
tokenize: bool = False,
**kwargs,
) -> str:
"""
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
conversations to turn them into a single tokenizable string.
Args:
conversation (`List[Dict, str, str]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
**kwargs:
Additional keyword arguments
"""
if chat_template is None:
if self.chat_template is not None:
chat_template = self.chat_template
elif getattr(self, "default_chat_template", None) is not None:
logger.warning_once(
"No chat template is set for this processor, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)
chat_template = self.default_chat_template
else:
raise ValueError(
"No chat template is set for this processor. Please either set the `chat_template` attribute, "
"or provide a chat template as an argument. See "
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
)
return self.tokenizer.apply_chat_template(
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
)
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
if ProcessorMixin.push_to_hub.__doc__ is not None: