Pipeline: simple API for assisted generation (#34504)

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Joao Gante 2025-01-08 17:08:02 +00:00 committed by GitHub
parent 3f483beab9
commit 76da6ca034
14 changed files with 172 additions and 18 deletions


@@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
 
+<Tip>
+
+If you're using a `pipeline` object, all you need to do is to pass the assistant checkpoint under `assistant_model`
+
+```python
+>>> from transformers import pipeline
+>>> import torch
+
+>>> pipe = pipeline(
+...     "text-generation",
+...     model="meta-llama/Llama-3.1-8B",
+...     assistant_model="meta-llama/Llama-3.2-1B",  # This extra line is all that's needed, also works with UAD
+...     torch_dtype=torch.bfloat16
+... )
+>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
+>>> pipe_output[0]["generated_text"]
+'Once upon a time, 3D printing was a niche technology that was only'
+```
+
+</Tip>
+
 When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
 just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
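For reference, the same `pipeline` call composes with the sampling note above. A minimal sketch reusing the checkpoints from the diff (`do_sample`, `temperature` and `max_new_tokens` are standard `generate` kwargs; the sampled output will vary):

```python
from transformers import pipeline
import torch

# Same assisted-generation setup as in the Tip, but with multinomial sampling.
pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-3.1-8B",
    assistant_model="meta-llama/Llama-3.2-1B",
    torch_dtype=torch.bfloat16,
)

# A lower temperature keeps sampling closer to greedy decoding, which tends to raise the
# assistant's acceptance rate and therefore improve latency.
out = pipe("Once upon a time, ", max_new_tokens=50, do_sample=True, temperature=0.5)
print(out[0]["generated_text"])
```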


@@ -347,7 +347,6 @@ class FlaxGenerationMixin:
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id
 
         if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:


@@ -773,7 +773,6 @@ class TFGenerationMixin:
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id
 
         use_xla = not tf.executing_eagerly()


@@ -348,6 +348,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 raise ValueError("Only Whisper can return language for now.")
             postprocess_params["return_language"] = return_language
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params
 
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
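With that plumbing in place, speculative decoding also works for speech recognition straight from the pipeline constructor. A sketch under the assumption that a distilled Whisper checkpoint (for example `distil-whisper/distil-large-v2`) is paired with its teacher, since the two share a tokenizer:

```python
from transformers import pipeline
import torch

# Assumed checkpoints: any Whisper assistant with a matching vocabulary works.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v2",
    assistant_model="distil-whisper/distil-large-v2",
    torch_dtype=torch.float16,
    device=0,  # assumes a CUDA device is available
)

# "sample.flac" is a placeholder path to any local audio file.
print(asr("sample.flac")["text"])
```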


@@ -33,7 +33,7 @@ from ..dynamic_module_utils import custom_object_save
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..image_processing_utils import BaseImageProcessor
 from ..modelcard import ModelCard
-from ..models.auto.configuration_auto import AutoConfig
+from ..models.auto import AutoConfig, AutoTokenizer
 from ..processing_utils import ProcessorMixin
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import (
@@ -425,6 +425,62 @@ def get_default_model_and_revision(
 
     return default_models[framework]
 
 
+def load_assistant_model(
+    model: "PreTrainedModel",
+    assistant_model: Optional[Union[str, "PreTrainedModel"]],
+    assistant_tokenizer: Optional[PreTrainedTokenizer],
+) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
+    """
+    Prepares the assistant model and the assistant tokenizer for a pipeline whose model can call `generate`.
+
+    Args:
+        model ([`PreTrainedModel`]):
+            The main model that will be used by the pipeline to make predictions.
+        assistant_model (`str` or [`PreTrainedModel`], *optional*):
+            The assistant model that will be used by the pipeline to make predictions.
+        assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
+            The assistant tokenizer that will be used by the pipeline to encode data for the model.
+
+    Returns:
+        Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
+    """
+    if not model.can_generate() or assistant_model is None:
+        return None, None
+
+    if not isinstance(model, PreTrainedModel):
+        raise ValueError(
+            "Assisted generation, triggered by the `assistant_model` argument, is only available for "
+            "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
+        )
+
+    # If the model is passed as a string, load the model and the corresponding tokenizer
+    if isinstance(assistant_model, str):
+        assistant_config = AutoConfig.from_pretrained(assistant_model)
+        _, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
+        loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
+        loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
+    else:
+        loaded_assistant_model = assistant_model
+        loaded_assistant_tokenizer = assistant_tokenizer
+
+    # Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the
+    # assistant tokenizer
+    same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
+    same_special_tokens = all(
+        getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
+        for token in ("eos_token_id", "pad_token_id", "bos_token_id")
+    )
+    if same_vocab_size and same_special_tokens:
+        loaded_assistant_tokenizer = None
+    elif loaded_assistant_tokenizer is None:
+        raise ValueError(
+            "The assistant model has a different tokenizer than the main model. You should pass the assistant "
+            "tokenizer."
+        )
+
+    return loaded_assistant_model, loaded_assistant_tokenizer
+
+
 class PipelineException(Exception):
     """
     Raised by a [`Pipeline`] when handling __call__.
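The helper is called from `Pipeline.__init__` (next hunk), but its contract is easiest to see in isolation. A hypothetical direct call, assuming the module path `transformers.pipelines.base` from this diff and a same-family checkpoint used for both roles:

```python
from transformers import AutoModelForCausalLM
from transformers.pipelines.base import load_assistant_model  # module path assumed from this diff

main_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Passing a checkpoint name: the helper loads the assistant (cast to the main
# model's device/dtype) together with its tokenizer.
assistant, assistant_tok = load_assistant_model(main_model, "meta-llama/Llama-3.2-1B", None)

# The tokenizer comes back as None when both models share vocab size and special
# tokens (plain assisted generation); it is only kept for the universal assisted
# decoding case, where a distinct assistant tokenizer is required.
print(type(assistant).__name__, assistant_tok)
```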
@@ -925,8 +981,13 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
         ):
             self.model.to(self.device)
 
-        # If the model can generate, create a local generation config. This is done to avoid side-effects on the model
-        # as we apply local tweaks to the generation config.
+        # If the model can generate:
+        # 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
+        #     tweaks to the generation config.
+        # 2 - load the assistant model if it is passed.
+        self.assistant_model, self.assistant_tokenizer = load_assistant_model(
+            self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
+        )
         if self.model.can_generate():
             self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
             self.generation_config = copy.deepcopy(self.model.generation_config)
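So the assistant is resolved once, at construction time, and stored on the pipeline; the per-task `_sanitize_parameters` changes below only forward it to `generate`. A rough check of the resulting attributes, using the tiny checkpoint from the tests at the end of this diff:

```python
from transformers import pipeline

name = "hf-internal-testing/tiny-random-MistralForCausalLM"  # tiny test checkpoint, as in the tests below
pipe = pipeline("text-generation", model=name, assistant_model=name)

print(pipe.assistant_model is not None)  # True: loaded onto the main model's device/dtype
print(pipe.assistant_tokenizer)          # None: both checkpoints share vocab and special tokens
```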


@@ -189,7 +189,14 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
         if handle_impossible_answer is not None:
             postprocess_params["handle_impossible_answer"] = handle_impossible_answer
 
-        return preprocess_params, {}, postprocess_params
+        forward_params = {}
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
+        return preprocess_params, forward_params, postprocess_params
 
     def __call__(
         self,


@@ -92,6 +92,12 @@ class ImageToTextPipeline(Pipeline):
                 )
             forward_params.update(generate_kwargs)
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}
 
     def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):


@@ -358,6 +358,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
         forward_params = {}
         if sequential is not None:
             forward_params["sequential"] = sequential
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}
 
     def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):


@@ -106,6 +106,12 @@ class Text2TextGenerationPipeline(Pipeline):
                 )
             generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params
 
     def check_inputs(self, input_length: int, min_length: int, max_length: int):


@@ -1,7 +1,6 @@
 import enum
 import itertools
 import types
-import warnings
 from typing import Dict
 
 from ..utils import add_end_docstrings, is_tf_available, is_torch_available
@@ -194,12 +193,13 @@ class TextGenerationPipeline(Pipeline):
 
         if stop_sequence is not None:
             stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
-            if len(stop_sequence_ids) > 1:
-                warnings.warn(
-                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
-                    " the stop sequence will be used as the stop sequence string in the interim."
-                )
-            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+            generate_kwargs["eos_token_id"] = stop_sequence_ids
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
 
         return preprocess_params, forward_params, postprocess_params
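Two things change in that last hunk: the assistant plumbing, and the handling of `stop_sequence`, whose full list of encoded token ids is now forwarded as `eos_token_id` instead of only its first token (the old warning is gone). A hedged sketch of the call shape, with a placeholder stop string and the tiny test checkpoint:

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="hf-internal-testing/tiny-random-MistralForCausalLM")

# `stop_sequence` is the pre-existing pipeline argument; every token id it encodes
# to is now treated as an end-of-sequence id by `generate`.
out = pipe("Hello world", max_new_tokens=20, stop_sequence="###")
print(out[0]["generated_text"])
```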


@@ -148,10 +148,9 @@ class TextToAudioPipeline(Pipeline):
         else:
             if len(generate_kwargs):
                 raise ValueError(
-                    f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
-                    For forward-only TTA models, please use `forward_params` instead of of
-                    `generate_kwargs`. For reference, here are the `generate_kwargs` used here:
-                    {generate_kwargs.keys()}"""
+                    "You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
+                    "empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
+                    f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
                 )
 
             output = self.model(**model_inputs, **forward_params)[0]
@@ -191,6 +190,12 @@ class TextToAudioPipeline(Pipeline):
         forward_params=None,
         generate_kwargs=None,
     ):
+        if self.assistant_model is not None:
+            generate_kwargs["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            generate_kwargs["tokenizer"] = self.tokenizer
+            generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer
+
         params = {
             "forward_params": forward_params if forward_params else {},
             "generate_kwargs": generate_kwargs if generate_kwargs else {},


@@ -66,7 +66,15 @@ class VisualQuestionAnsweringPipeline(Pipeline):
             preprocess_params["timeout"] = timeout
         if top_k is not None:
             postprocess_params["top_k"] = top_k
-        return preprocess_params, {}, postprocess_params
+
+        forward_params = {}
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
+        return preprocess_params, forward_params, postprocess_params
 
     def __call__(
         self,


@@ -1933,6 +1933,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             },
         )
 
+    @require_torch
+    def test_pipeline_assisted_generation(self):
+        """Tests that we can run assisted generation in the pipeline"""
+        model = "openai/whisper-tiny"
+        pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)
+
+        # We can run the pipeline
+        prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
+        _ = pipe(prompt)
+
+        # It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
+        with self.assertRaises(ValueError):
+            _ = pipe(prompt, generate_kwargs={"num_beams": 2})
+
 
 def require_ffmpeg(test_case):
     """
def require_ffmpeg(test_case): def require_ffmpeg(test_case):
""" """


@@ -653,3 +653,17 @@ class TextGenerationPipelineTests(unittest.TestCase):
             with CaptureLogger(logger) as cl:
                 _ = text_generator(prompt, max_length=10)
             self.assertNotIn(logger_msg, cl.out)
+
+    @require_torch
+    def test_pipeline_assisted_generation(self):
+        """Tests that we can run assisted generation in the pipeline"""
+        model = "hf-internal-testing/tiny-random-MistralForCausalLM"
+        pipe = pipeline("text-generation", model=model, assistant_model=model)
+
+        # We can run the pipeline
+        prompt = "Hello world"
+        _ = pipe(prompt)
+
+        # It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
+        with self.assertRaises(ValueError):
+            _ = pipe(prompt, generate_kwargs={"num_beams": 2})
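For context on why the `num_beams` call is expected to raise: assisted generation only supports greedy search and sampling, so `generate` itself rejects beam search when an `assistant_model` is present. A minimal sketch outside the pipeline, reusing the tiny checkpoint from the test above (the exact error message may differ across versions):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "hf-internal-testing/tiny-random-MistralForCausalLM"
model = AutoModelForCausalLM.from_pretrained(name)
assistant = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

inputs = tokenizer("Hello world", return_tensors="pt")
try:
    model.generate(**inputs, assistant_model=assistant, num_beams=2, max_new_tokens=5)
except ValueError as exc:
    # Beam search and assisted generation are mutually exclusive, hence the assertRaises in the test.
    print(f"Rejected as expected: {exc}")
```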