mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Pipeline: simple API for assisted generation (#34504)
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent
3f483beab9
commit
76da6ca034
@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
|
||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If you're using a `pipeline` object, all you need to do is to pass the assistant checkpoint under `assistant_model`
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
>>> import torch
|
||||
|
||||
>>> pipe = pipeline(
|
||||
... "text-generation",
|
||||
... model="meta-llama/Llama-3.1-8B",
|
||||
... assistant_model="meta-llama/Llama-3.2-1B", # This extra line is all that's needed, also works with UAD
|
||||
... torch_dtype=torch.bfloat16
|
||||
>>> )
|
||||
>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
|
||||
>>> pipe_output[0]["generated_text"]
|
||||
'Once upon a time, 3D printing was a niche technology that was only'
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
|
||||
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
|
||||
|
||||
|
@ -347,7 +347,6 @@ class FlaxGenerationMixin:
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
|
||||
|
@ -773,7 +773,6 @@ class TFGenerationMixin:
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
use_xla = not tf.executing_eagerly()
|
||||
|
@ -348,6 +348,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
raise ValueError("Only Whisper can return language for now.")
|
||||
postprocess_params["return_language"] = return_language
|
||||
|
||||
if self.assistant_model is not None:
|
||||
forward_params["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
forward_params["tokenizer"] = self.tokenizer
|
||||
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
|
||||
|
@ -33,7 +33,7 @@ from ..dynamic_module_utils import custom_object_save
|
||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
from ..image_processing_utils import BaseImageProcessor
|
||||
from ..modelcard import ModelCard
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..models.auto import AutoConfig, AutoTokenizer
|
||||
from ..processing_utils import ProcessorMixin
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import (
|
||||
@ -425,6 +425,62 @@ def get_default_model_and_revision(
|
||||
return default_models[framework]
|
||||
|
||||
|
||||
def load_assistant_model(
|
||||
model: "PreTrainedModel",
|
||||
assistant_model: Optional[Union[str, "PreTrainedModel"]],
|
||||
assistant_tokenizer: Optional[PreTrainedTokenizer],
|
||||
) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
|
||||
"""
|
||||
Prepares the assistant model and the assistant tokenizer for a pipeline whose model that can call `generate`.
|
||||
|
||||
Args:
|
||||
model ([`PreTrainedModel`]):
|
||||
The main model that will be used by the pipeline to make predictions.
|
||||
assistant_model (`str` or [`PreTrainedModel`], *optional*):
|
||||
The assistant model that will be used by the pipeline to make predictions.
|
||||
assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
|
||||
The assistant tokenizer that will be used by the pipeline to encode data for the model.
|
||||
|
||||
Returns:
|
||||
Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
|
||||
"""
|
||||
if not model.can_generate() or assistant_model is None:
|
||||
return None, None
|
||||
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
raise ValueError(
|
||||
"Assisted generation, triggered by the `assistant_model` argument, is only available for "
|
||||
"`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
|
||||
)
|
||||
|
||||
# If the model is passed as a string, load the model and the corresponding tokenizer
|
||||
if isinstance(assistant_model, str):
|
||||
assistant_config = AutoConfig.from_pretrained(assistant_model)
|
||||
_, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
|
||||
loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
|
||||
loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
|
||||
else:
|
||||
loaded_assistant_model = assistant_model
|
||||
loaded_assistant_tokenizer = assistant_tokenizer
|
||||
|
||||
# Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant
|
||||
# tokenizer
|
||||
same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
|
||||
same_special_tokens = all(
|
||||
getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
|
||||
for token in ("eos_token_id", "pad_token_id", "bos_token_id")
|
||||
)
|
||||
if same_vocab_size and same_special_tokens:
|
||||
loaded_assistant_tokenizer = None
|
||||
elif loaded_assistant_tokenizer is None:
|
||||
raise ValueError(
|
||||
"The assistant model has a different tokenizer than the main model. You should pass the assistant "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
return loaded_assistant_model, loaded_assistant_tokenizer
|
||||
|
||||
|
||||
class PipelineException(Exception):
|
||||
"""
|
||||
Raised by a [`Pipeline`] when handling __call__.
|
||||
@ -925,8 +981,13 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
):
|
||||
self.model.to(self.device)
|
||||
|
||||
# If the model can generate, create a local generation config. This is done to avoid side-effects on the model
|
||||
# as we apply local tweaks to the generation config.
|
||||
# If the model can generate:
|
||||
# 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
|
||||
# tweaks to the generation config.
|
||||
# 2 - load the assistant model if it is passed.
|
||||
self.assistant_model, self.assistant_tokenizer = load_assistant_model(
|
||||
self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
|
||||
)
|
||||
if self.model.can_generate():
|
||||
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
|
||||
self.generation_config = copy.deepcopy(self.model.generation_config)
|
||||
|
@ -189,7 +189,14 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
||||
if handle_impossible_answer is not None:
|
||||
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
|
||||
|
||||
return preprocess_params, {}, postprocess_params
|
||||
forward_params = {}
|
||||
if self.assistant_model is not None:
|
||||
forward_params["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
forward_params["tokenizer"] = self.tokenizer
|
||||
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -92,6 +92,12 @@ class ImageToTextPipeline(Pipeline):
|
||||
)
|
||||
forward_params.update(generate_kwargs)
|
||||
|
||||
if self.assistant_model is not None:
|
||||
forward_params["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
forward_params["tokenizer"] = self.tokenizer
|
||||
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
return preprocess_params, forward_params, {}
|
||||
|
||||
def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
|
||||
|
@ -358,6 +358,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
forward_params = {}
|
||||
if sequential is not None:
|
||||
forward_params["sequential"] = sequential
|
||||
|
||||
if self.assistant_model is not None:
|
||||
forward_params["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
forward_params["tokenizer"] = self.tokenizer
|
||||
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
return preprocess_params, forward_params, {}
|
||||
|
||||
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
|
||||
|
@ -106,6 +106,12 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
)
|
||||
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
|
||||
|
||||
if self.assistant_model is not None:
|
||||
forward_params["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
forward_params["tokenizer"] = self.tokenizer
|
||||
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def check_inputs(self, input_length: int, min_length: int, max_length: int):
|
||||
|
@ -1,7 +1,6 @@
|
||||
import enum
|
||||
import itertools
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict
|
||||
|
||||
from ..utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||
@ -194,12 +193,13 @@ class TextGenerationPipeline(Pipeline):
|
||||
|
||||
if stop_sequence is not None:
|
||||
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
|
||||
if len(stop_sequence_ids) > 1:
|
||||
warnings.warn(
|
||||
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
|
||||
" the stop sequence will be used as the stop sequence string in the interim."
|
||||
)
|
||||
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
|
||||
generate_kwargs["eos_token_id"] = stop_sequence_ids
|
||||
|
||||
if self.assistant_model is not None:
|
||||
forward_params["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
forward_params["tokenizer"] = self.tokenizer
|
||||
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
|
@ -148,10 +148,9 @@ class TextToAudioPipeline(Pipeline):
|
||||
else:
|
||||
if len(generate_kwargs):
|
||||
raise ValueError(
|
||||
f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
|
||||
For forward-only TTA models, please use `forward_params` instead of of
|
||||
`generate_kwargs`. For reference, here are the `generate_kwargs` used here:
|
||||
{generate_kwargs.keys()}"""
|
||||
"You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
|
||||
"empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
|
||||
f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
|
||||
)
|
||||
output = self.model(**model_inputs, **forward_params)[0]
|
||||
|
||||
@ -191,6 +190,12 @@ class TextToAudioPipeline(Pipeline):
|
||||
forward_params=None,
|
||||
generate_kwargs=None,
|
||||
):
|
||||
if self.assistant_model is not None:
|
||||
generate_kwargs["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
generate_kwargs["tokenizer"] = self.tokenizer
|
||||
generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
params = {
|
||||
"forward_params": forward_params if forward_params else {},
|
||||
"generate_kwargs": generate_kwargs if generate_kwargs else {},
|
||||
|
@ -66,7 +66,15 @@ class VisualQuestionAnsweringPipeline(Pipeline):
|
||||
preprocess_params["timeout"] = timeout
|
||||
if top_k is not None:
|
||||
postprocess_params["top_k"] = top_k
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
forward_params = {}
|
||||
if self.assistant_model is not None:
|
||||
forward_params["assistant_model"] = self.assistant_model
|
||||
if self.assistant_tokenizer is not None:
|
||||
forward_params["tokenizer"] = self.tokenizer
|
||||
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -1933,6 +1933,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_assisted_generation(self):
|
||||
"""Tests that we can run assisted generation in the pipeline"""
|
||||
model = "openai/whisper-tiny"
|
||||
pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)
|
||||
|
||||
# We can run the pipeline
|
||||
prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
|
||||
_ = pipe(prompt)
|
||||
|
||||
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = pipe(prompt, generate_kwargs={"num_beams": 2})
|
||||
|
||||
|
||||
def require_ffmpeg(test_case):
|
||||
"""
|
||||
|
@ -653,3 +653,17 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_length=10)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_assisted_generation(self):
|
||||
"""Tests that we can run assisted generation in the pipeline"""
|
||||
model = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
pipe = pipeline("text-generation", model=model, assistant_model=model)
|
||||
|
||||
# We can run the pipeline
|
||||
prompt = "Hello world"
|
||||
_ = pipe(prompt)
|
||||
|
||||
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = pipe(prompt, generate_kwargs={"num_beams": 2})
|
||||
|
Loading…
Reference in New Issue
Block a user