Pipeline: simple API for assisted generation (#34504)

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Joao Gante 2025-01-08 17:08:02 +00:00 committed by GitHub
parent 3f483beab9
commit 76da6ca034
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 172 additions and 18 deletions

View File

@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
<Tip>
If you're using a `pipeline` object, all you need to do is to pass the assistant checkpoint under `assistant_model`
```python
>>> from transformers import pipeline
>>> import torch
>>> pipe = pipeline(
... "text-generation",
... model="meta-llama/Llama-3.1-8B",
... assistant_model="meta-llama/Llama-3.2-1B", # This extra line is all that's needed, also works with UAD
... torch_dtype=torch.bfloat16
>>> )
>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
>>> pipe_output[0]["generated_text"]
'Once upon a time, 3D printing was a niche technology that was only'
```
</Tip>
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.

View File

@ -347,7 +347,6 @@ class FlaxGenerationMixin:
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:

View File

@ -773,7 +773,6 @@ class TFGenerationMixin:
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
use_xla = not tf.executing_eagerly()

View File

@ -348,6 +348,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("Only Whisper can return language for now.")
postprocess_params["return_language"] = return_language
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
return preprocess_params, forward_params, postprocess_params
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):

View File

@ -33,7 +33,7 @@ from ..dynamic_module_utils import custom_object_save
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..image_processing_utils import BaseImageProcessor
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto import AutoConfig, AutoTokenizer
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
@ -425,6 +425,62 @@ def get_default_model_and_revision(
return default_models[framework]
def load_assistant_model(
model: "PreTrainedModel",
assistant_model: Optional[Union[str, "PreTrainedModel"]],
assistant_tokenizer: Optional[PreTrainedTokenizer],
) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
"""
Prepares the assistant model and the assistant tokenizer for a pipeline whose model that can call `generate`.
Args:
model ([`PreTrainedModel`]):
The main model that will be used by the pipeline to make predictions.
assistant_model (`str` or [`PreTrainedModel`], *optional*):
The assistant model that will be used by the pipeline to make predictions.
assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
The assistant tokenizer that will be used by the pipeline to encode data for the model.
Returns:
Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
"""
if not model.can_generate() or assistant_model is None:
return None, None
if not isinstance(model, PreTrainedModel):
raise ValueError(
"Assisted generation, triggered by the `assistant_model` argument, is only available for "
"`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
)
# If the model is passed as a string, load the model and the corresponding tokenizer
if isinstance(assistant_model, str):
assistant_config = AutoConfig.from_pretrained(assistant_model)
_, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
else:
loaded_assistant_model = assistant_model
loaded_assistant_tokenizer = assistant_tokenizer
# Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant
# tokenizer
same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
same_special_tokens = all(
getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
for token in ("eos_token_id", "pad_token_id", "bos_token_id")
)
if same_vocab_size and same_special_tokens:
loaded_assistant_tokenizer = None
elif loaded_assistant_tokenizer is None:
raise ValueError(
"The assistant model has a different tokenizer than the main model. You should pass the assistant "
"tokenizer."
)
return loaded_assistant_model, loaded_assistant_tokenizer
class PipelineException(Exception):
"""
Raised by a [`Pipeline`] when handling __call__.
@ -925,8 +981,13 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
):
self.model.to(self.device)
# If the model can generate, create a local generation config. This is done to avoid side-effects on the model
# as we apply local tweaks to the generation config.
# If the model can generate:
# 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
# tweaks to the generation config.
# 2 - load the assistant model if it is passed.
self.assistant_model, self.assistant_tokenizer = load_assistant_model(
self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
)
if self.model.can_generate():
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
self.generation_config = copy.deepcopy(self.model.generation_config)

View File

@ -189,7 +189,14 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
if handle_impossible_answer is not None:
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
return preprocess_params, {}, postprocess_params
forward_params = {}
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
return preprocess_params, forward_params, postprocess_params
def __call__(
self,

View File

@ -92,6 +92,12 @@ class ImageToTextPipeline(Pipeline):
)
forward_params.update(generate_kwargs)
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
return preprocess_params, forward_params, {}
def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):

View File

@ -358,6 +358,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
forward_params = {}
if sequential is not None:
forward_params["sequential"] = sequential
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
return preprocess_params, forward_params, {}
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):

View File

@ -106,6 +106,12 @@ class Text2TextGenerationPipeline(Pipeline):
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
return preprocess_params, forward_params, postprocess_params
def check_inputs(self, input_length: int, min_length: int, max_length: int):

View File

@ -1,7 +1,6 @@
import enum
import itertools
import types
import warnings
from typing import Dict
from ..utils import add_end_docstrings, is_tf_available, is_torch_available
@ -194,12 +193,13 @@ class TextGenerationPipeline(Pipeline):
if stop_sequence is not None:
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
if len(stop_sequence_ids) > 1:
warnings.warn(
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
" the stop sequence will be used as the stop sequence string in the interim."
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
generate_kwargs["eos_token_id"] = stop_sequence_ids
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
return preprocess_params, forward_params, postprocess_params

View File

@ -148,10 +148,9 @@ class TextToAudioPipeline(Pipeline):
else:
if len(generate_kwargs):
raise ValueError(
f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
For forward-only TTA models, please use `forward_params` instead of of
`generate_kwargs`. For reference, here are the `generate_kwargs` used here:
{generate_kwargs.keys()}"""
"You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
"empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
)
output = self.model(**model_inputs, **forward_params)[0]
@ -191,6 +190,12 @@ class TextToAudioPipeline(Pipeline):
forward_params=None,
generate_kwargs=None,
):
if self.assistant_model is not None:
generate_kwargs["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
generate_kwargs["tokenizer"] = self.tokenizer
generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer
params = {
"forward_params": forward_params if forward_params else {},
"generate_kwargs": generate_kwargs if generate_kwargs else {},

View File

@ -66,7 +66,15 @@ class VisualQuestionAnsweringPipeline(Pipeline):
preprocess_params["timeout"] = timeout
if top_k is not None:
postprocess_params["top_k"] = top_k
return preprocess_params, {}, postprocess_params
forward_params = {}
if self.assistant_model is not None:
forward_params["assistant_model"] = self.assistant_model
if self.assistant_tokenizer is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
return preprocess_params, forward_params, postprocess_params
def __call__(
self,

View File

@ -1933,6 +1933,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
},
)
@require_torch
def test_pipeline_assisted_generation(self):
"""Tests that we can run assisted generation in the pipeline"""
model = "openai/whisper-tiny"
pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)
# We can run the pipeline
prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
_ = pipe(prompt)
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
with self.assertRaises(ValueError):
_ = pipe(prompt, generate_kwargs={"num_beams": 2})
def require_ffmpeg(test_case):
"""

View File

@ -653,3 +653,17 @@ class TextGenerationPipelineTests(unittest.TestCase):
with CaptureLogger(logger) as cl:
_ = text_generator(prompt, max_length=10)
self.assertNotIn(logger_msg, cl.out)
@require_torch
def test_pipeline_assisted_generation(self):
"""Tests that we can run assisted generation in the pipeline"""
model = "hf-internal-testing/tiny-random-MistralForCausalLM"
pipe = pipeline("text-generation", model=model, assistant_model=model)
# We can run the pipeline
prompt = "Hello world"
_ = pipe(prompt)
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
with self.assertRaises(ValueError):
_ = pipe(prompt, generate_kwargs={"num_beams": 2})