From 15b3923d656cf3e1a0c20d3ddc356eb89a703ce0 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 13 Jun 2024 14:35:30 +0100 Subject: [PATCH] Make chat templates part of ProcessorMixin (#30744) * Let's try moving chat templates out of IDEFICS and into the generic ProcessorMixin * Chat templates should not be mandatory * Chat templates should not be mandatory * Not all classes will have default chat templates * stash commit * Add chat template docstring * Clean up docstring * Add chat templates to LLaVA/LLaVA-next * Docstring fixup * Quick IDEFICS2 fixup * Remove some old references to the Conversation class * make fixup --- .../models/idefics2/processing_idefics2.py | 54 ++----------------- .../models/llava/processing_llava.py | 6 ++- .../llava_next/processing_llava_next.py | 6 ++- src/transformers/processing_utils.py | 52 +++++++++++++++++- 4 files changed, 64 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py index e9f9f923373..4edb1813b8e 100644 --- a/src/transformers/models/idefics2/processing_idefics2.py +++ b/src/transformers/models/idefics2/processing_idefics2.py @@ -16,7 +16,7 @@ Processor class for IDEFICS2. """ -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image, load_image @@ -56,13 +56,15 @@ class Idefics2Processor(ProcessorMixin): The length of the image sequence i.e. the number of tokens per image in the input. This parameter is used to build the string from the input prompt and image tokens and should match the config.perceiver_config.resampler_n_latents value for the model used. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. """ attributes = ["image_processor", "tokenizer"] image_processor_class = "Idefics2ImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **kwargs): + def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, chat_template: str = None, **kwargs): if image_processor is None: raise ValueError("You need to specify an `image_processor`.") if tokenizer is None: @@ -78,10 +80,7 @@ class Idefics2Processor(ProcessorMixin): } tokenizer.add_special_tokens(tokens_to_add) - # Stores a Jinja template that formats chat histories into tokenizable strings - self.chat_template = kwargs.pop("chat_template", None) - - super().__init__(image_processor, tokenizer) + super().__init__(image_processor, tokenizer, chat_template=chat_template) def _extract_images_from_prompts(self, prompts): prompt_images = [] @@ -252,49 +251,6 @@ class Idefics2Processor(ProcessorMixin): image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) - def apply_chat_template( - self, - conversation: Union[List[Dict[str, str]]], - chat_template: Optional[str] = None, - tokenize: bool = False, - **kwargs, - ) -> str: - """ - Overrides the tokenizer's `apply_chat_template` method to apply the IDEFICS2 chat template by default - if no chat template is provided. - - By default, the output isn't tokenized. This is because the IDEFICS2 chat template is designed to insert - the image token into the sequence according to the message, but does not handle expanding the image - tokens to the sequence length or adding the surrounding tokens e.g. . - - Args: - conversation (`Union[List[Dict, str, str]]`): - The conversation to format. - chat_template (`Optional[str]`, *optional*): - The Jinja template to use for formatting the conversation. If not provided, the default chat template - is used. - tokenize (`bool`, *optional*, defaults to `False`): - Whether to tokenize the output or not. - **kwargs: - Additional keyword arguments for the tokenizer's `apply_chat_template` method. - """ - - if chat_template is None: - if self.chat_template is not None: - chat_template = self.chat_template - else: - logger.warning_once( - "No chat template is set for this processor, falling back to a default class-level template. This is " - "very error-prone, because models are often trained with templates different from the class default! " - "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " - "point any code depending on them will stop working. We recommend setting a valid chat template before " - "then to ensure that this model continues working without issues." - ) - chat_template = self.default_chat_template - return self.tokenizer.apply_chat_template( - conversation, chat_template=chat_template, tokenize=tokenize, **kwargs - ) - @property def default_chat_template(self): """ diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 7016cd50096..96d38c53c94 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -37,14 +37,16 @@ class LlavaProcessor(ProcessorMixin): The image processor is a required input. tokenizer ([`LlamaTokenizerFast`], *optional*): The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. """ attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor=None, tokenizer=None): - super().__init__(image_processor, tokenizer) + def __init__(self, image_processor=None, tokenizer=None, chat_template=None): + super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( self, diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 91cd544ab64..6c2ca2f9028 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -37,14 +37,16 @@ class LlavaNextProcessor(ProcessorMixin): The image processor is a required input. tokenizer ([`LlamaTokenizerFast`], *optional*): The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. """ attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor=None, tokenizer=None): - super().__init__(image_processor, tokenizer) + def __init__(self, image_processor=None, tokenizer=None, chat_template=None): + super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( self, diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index d76fa4dcccc..a21d265b9d1 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -22,7 +22,7 @@ import json import os import warnings from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from .dynamic_module_utils import custom_object_save from .tokenization_utils_base import PreTrainedTokenizerBase @@ -60,6 +60,7 @@ class ProcessorMixin(PushToHubMixin): """ attributes = ["feature_extractor", "tokenizer"] + optional_attributes = ["chat_template"] # Names need to be attr_class for attr in attributes feature_extractor_class = None tokenizer_class = None @@ -67,6 +68,10 @@ class ProcessorMixin(PushToHubMixin): # args have to match the attributes class attribute def __init__(self, *args, **kwargs): + # First, extract optional attributes from kwargs if present + # Optional attributes can never be positional arguments + for optional_attribute in self.optional_attributes: + setattr(self, optional_attribute, kwargs.pop(optional_attribute, None)) # Sanitize args and kwargs for key in kwargs: if key not in self.attributes: @@ -522,6 +527,51 @@ class ProcessorMixin(PushToHubMixin): first_attribute = getattr(self, self.attributes[0]) return getattr(first_attribute, "model_input_names", None) + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]]], + chat_template: Optional[str] = None, + tokenize: bool = False, + **kwargs, + ) -> str: + """ + Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input + conversations to turn them into a single tokenizable string. + + Args: + conversation (`List[Dict, str, str]`): + The conversation to format. + chat_template (`Optional[str]`, *optional*): + The Jinja template to use for formatting the conversation. If not provided, the default chat template + is used. + tokenize (`bool`, *optional*, defaults to `False`): + Whether to tokenize the output or not. + **kwargs: + Additional keyword arguments + """ + + if chat_template is None: + if self.chat_template is not None: + chat_template = self.chat_template + elif getattr(self, "default_chat_template", None) is not None: + logger.warning_once( + "No chat template is set for this processor, falling back to a default class-level template. This is " + "very error-prone, because models are often trained with templates different from the class default! " + "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " + "point any code depending on them will stop working. We recommend setting a valid chat template before " + "then to ensure that this model continues working without issues." + ) + chat_template = self.default_chat_template + else: + raise ValueError( + "No chat template is set for this processor. Please either set the `chat_template` attribute, " + "or provide a chat template as an argument. See " + "https://huggingface.co/docs/transformers/main/en/chat_templating for more information." + ) + return self.tokenizer.apply_chat_template( + conversation, chat_template=chat_template, tokenize=tokenize, **kwargs + ) + ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) if ProcessorMixin.push_to_hub.__doc__ is not None: