Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)

Pipeline: simple API for assisted generation (#34504)

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

parent 3f483beab9
commit 76da6ca034
````diff
@@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
 
+<Tip>
+
+If you're using a `pipeline` object, all you need to do is to pass the assistant checkpoint under `assistant_model`
+
+```python
+>>> from transformers import pipeline
+>>> import torch
+
+>>> pipe = pipeline(
+...     "text-generation",
+...     model="meta-llama/Llama-3.1-8B",
+...     assistant_model="meta-llama/Llama-3.2-1B",  # This extra line is all that's needed, also works with UAD
+...     torch_dtype=torch.bfloat16
+>>> )
+>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
+>>> pipe_output[0]["generated_text"]
+'Once upon a time, 3D printing was a niche technology that was only'
+```
+
+</Tip>
+
 When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
 just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
 
````
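The context lines above note that lowering `temperature` can improve latency under assisted decoding. As a hedged sketch (not part of the diff), the snippet below combines the new `assistant_model` pipeline argument with sampling; it reuses the Llama checkpoints from the docs example and assumes generate kwargs such as `temperature` are forwarded through the pipeline call as usual.

```python
# Sketch only (not from the diff): assisted decoding combined with sampling.
# Checkpoints are taken from the docs example above; a lower `temperature`
# typically raises the acceptance rate of the assistant's draft tokens, which
# is where the latency gain mentioned in the docs comes from.
from transformers import pipeline
import torch

pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-3.1-8B",            # main model
    assistant_model="meta-llama/Llama-3.2-1B",  # draft (assistant) model
    torch_dtype=torch.bfloat16,
)
out = pipe(
    "Once upon a time, ",
    max_new_tokens=50,
    do_sample=True,
    temperature=0.3,  # forwarded to `generate` by the pipeline
)
print(out[0]["generated_text"])
```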
```diff
@@ -347,7 +347,6 @@ class FlaxGenerationMixin:
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id
 
         if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
```
```diff
@@ -773,7 +773,6 @@ class TFGenerationMixin:
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id
 
         use_xla = not tf.executing_eagerly()
```
```diff
@@ -348,6 +348,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 raise ValueError("Only Whisper can return language for now.")
             postprocess_params["return_language"] = return_language
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params
 
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
```
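A hedged usage sketch for the automatic-speech-recognition change above. The checkpoints (`openai/whisper-large-v2`, `distil-whisper/distil-large-v2`) and the audio path are illustrative assumptions, not part of the diff.

```python
# Sketch only: Whisper with a distilled checkpoint as the draft model.
# Checkpoint names and "sample.flac" are placeholders.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v2",
    assistant_model="distil-whisper/distil-large-v2",  # assumed-compatible draft model
)
print(asr("sample.flac")["text"])
```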
```diff
@@ -33,7 +33,7 @@ from ..dynamic_module_utils import custom_object_save
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..image_processing_utils import BaseImageProcessor
 from ..modelcard import ModelCard
-from ..models.auto.configuration_auto import AutoConfig
+from ..models.auto import AutoConfig, AutoTokenizer
 from ..processing_utils import ProcessorMixin
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import (
```
```diff
@@ -425,6 +425,62 @@ def get_default_model_and_revision(
     return default_models[framework]
 
 
+def load_assistant_model(
+    model: "PreTrainedModel",
+    assistant_model: Optional[Union[str, "PreTrainedModel"]],
+    assistant_tokenizer: Optional[PreTrainedTokenizer],
+) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
+    """
+    Prepares the assistant model and the assistant tokenizer for a pipeline whose model that can call `generate`.
+
+    Args:
+        model ([`PreTrainedModel`]):
+            The main model that will be used by the pipeline to make predictions.
+        assistant_model (`str` or [`PreTrainedModel`], *optional*):
+            The assistant model that will be used by the pipeline to make predictions.
+        assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
+            The assistant tokenizer that will be used by the pipeline to encode data for the model.
+
+    Returns:
+        Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
+    """
+    if not model.can_generate() or assistant_model is None:
+        return None, None
+
+    if not isinstance(model, PreTrainedModel):
+        raise ValueError(
+            "Assisted generation, triggered by the `assistant_model` argument, is only available for "
+            "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
+        )
+
+    # If the model is passed as a string, load the model and the corresponding tokenizer
+    if isinstance(assistant_model, str):
+        assistant_config = AutoConfig.from_pretrained(assistant_model)
+        _, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
+        loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
+        loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
+    else:
+        loaded_assistant_model = assistant_model
+        loaded_assistant_tokenizer = assistant_tokenizer
+
+    # Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant
+    # tokenizer
+    same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
+    same_special_tokens = all(
+        getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
+        for token in ("eos_token_id", "pad_token_id", "bos_token_id")
+    )
+    if same_vocab_size and same_special_tokens:
+        loaded_assistant_tokenizer = None
+    elif loaded_assistant_tokenizer is None:
+        raise ValueError(
+            "The assistant model has a different tokenizer than the main model. You should pass the assistant "
+            "tokenizer."
+        )
+
+    return loaded_assistant_model, loaded_assistant_tokenizer
+
+
 class PipelineException(Exception):
     """
     Raised by a [`Pipeline`] when handling __call__.
```
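The helper above takes two paths: a string checkpoint is loaded onto the main model's device and dtype together with its tokenizer, while a preloaded model is used as-is. It then drops the assistant tokenizer when both models share vocab size and special token ids, and raises a `ValueError` when they differ and no assistant tokenizer was provided. A hedged sketch of the two user-facing cases, with placeholder checkpoints:

```python
# Sketch only; checkpoint names are placeholders, not part of the diff.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Case 1: assistant shares the main model's vocabulary and special tokens,
# so `assistant_model` alone is enough (the assistant tokenizer is dropped).
pipe = pipeline("text-generation", model="gpt2", assistant_model="distilgpt2")

# Case 2: a preloaded assistant with a different tokenizer must come with its
# tokenizer, otherwise `load_assistant_model` raises a ValueError.
assistant = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m")  # placeholder
assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")  # placeholder
pipe_uad = pipeline(
    "text-generation",
    model="gpt2",
    assistant_model=assistant,
    assistant_tokenizer=assistant_tokenizer,
)
```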
```diff
@@ -925,8 +981,13 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
         ):
             self.model.to(self.device)
 
-        # If the model can generate, create a local generation config. This is done to avoid side-effects on the model
-        # as we apply local tweaks to the generation config.
+        # If the model can generate:
+        # 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
+        # tweaks to the generation config.
+        # 2 - load the assistant model if it is passed.
+        self.assistant_model, self.assistant_tokenizer = load_assistant_model(
+            self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
+        )
         if self.model.can_generate():
             self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
             self.generation_config = copy.deepcopy(self.model.generation_config)
```
```diff
@@ -189,7 +189,14 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
         if handle_impossible_answer is not None:
             postprocess_params["handle_impossible_answer"] = handle_impossible_answer
 
-        return preprocess_params, {}, postprocess_params
+        forward_params = {}
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
+        return preprocess_params, forward_params, postprocess_params
 
     def __call__(
         self,
```
```diff
@@ -92,6 +92,12 @@ class ImageToTextPipeline(Pipeline):
                 )
             forward_params.update(generate_kwargs)
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}
 
     def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
```
```diff
@@ -358,6 +358,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
         forward_params = {}
         if sequential is not None:
             forward_params["sequential"] = sequential
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}
 
     def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
```
```diff
@@ -106,6 +106,12 @@ class Text2TextGenerationPipeline(Pipeline):
                 )
             generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params
 
     def check_inputs(self, input_length: int, min_length: int, max_length: int):
```
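The same two attributes are forwarded by every generate-capable pipeline touched in this commit, so the feature also reaches subclasses of `Text2TextGenerationPipeline` such as summarization. A hedged sketch with placeholder checkpoints assumed to share a tokenizer:

```python
# Sketch only; checkpoints are placeholders assumed to share a tokenizer.
from transformers import pipeline

summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    assistant_model="sshleifer/distilbart-cnn-12-6",  # assumed-compatible draft model
)
article = "The tower is 324 metres tall, about the same height as an 81-storey building..."
print(summarizer(article, max_new_tokens=60)[0]["summary_text"])
```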
```diff
@@ -1,7 +1,6 @@
 import enum
 import itertools
 import types
-import warnings
 from typing import Dict
 
 from ..utils import add_end_docstrings, is_tf_available, is_torch_available
```
```diff
@@ -194,12 +193,13 @@ class TextGenerationPipeline(Pipeline):
 
         if stop_sequence is not None:
             stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
-            if len(stop_sequence_ids) > 1:
-                warnings.warn(
-                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
-                    " the stop sequence will be used as the stop sequence string in the interim."
-                )
-            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+            generate_kwargs["eos_token_id"] = stop_sequence_ids
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
 
         return preprocess_params, forward_params, postprocess_params
 
```
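Beyond wiring in the assistant, the hunk above also reworks stop-sequence handling: the encoded `stop_sequence` is now forwarded to `generate` as the full list of token ids under `eos_token_id`, and the multi-token warning (together with the `warnings` import removed above) is gone. A hedged usage sketch with a placeholder checkpoint:

```python
# Sketch only; the checkpoint and the stop string are placeholders.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
# `stop_sequence` is encoded in _sanitize_parameters; after this change every
# token id of the encoded sequence is handed to `generate` via `eos_token_id`,
# instead of only the first one.
out = generator("The quick brown fox", max_new_tokens=40, stop_sequence=".")
print(out[0]["generated_text"])
```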
```diff
@@ -148,10 +148,9 @@ class TextToAudioPipeline(Pipeline):
         else:
             if len(generate_kwargs):
                 raise ValueError(
-                    f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
-                     For forward-only TTA models, please use `forward_params` instead of of
-                     `generate_kwargs`. For reference, here are the `generate_kwargs` used here:
-                     {generate_kwargs.keys()}"""
+                    "You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
+                    "empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
+                    f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
                 )
             output = self.model(**model_inputs, **forward_params)[0]
 
```
```diff
@@ -191,6 +190,12 @@ class TextToAudioPipeline(Pipeline):
         forward_params=None,
         generate_kwargs=None,
     ):
+        if self.assistant_model is not None:
+            generate_kwargs["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            generate_kwargs["tokenizer"] = self.tokenizer
+            generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer
+
         params = {
             "forward_params": forward_params if forward_params else {},
             "generate_kwargs": generate_kwargs if generate_kwargs else {},
```
```diff
@@ -66,7 +66,15 @@ class VisualQuestionAnsweringPipeline(Pipeline):
             preprocess_params["timeout"] = timeout
         if top_k is not None:
             postprocess_params["top_k"] = top_k
-        return preprocess_params, {}, postprocess_params
+
+        forward_params = {}
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+        if self.assistant_tokenizer is not None:
+            forward_params["tokenizer"] = self.tokenizer
+            forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
+        return preprocess_params, forward_params, postprocess_params
 
     def __call__(
         self,
```
```diff
@@ -1933,6 +1933,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             },
         )
 
+    @require_torch
+    def test_pipeline_assisted_generation(self):
+        """Tests that we can run assisted generation in the pipeline"""
+        model = "openai/whisper-tiny"
+        pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)
+
+        # We can run the pipeline
+        prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
+        _ = pipe(prompt)
+
+        # It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
+        with self.assertRaises(ValueError):
+            _ = pipe(prompt, generate_kwargs={"num_beams": 2})
+
 
 def require_ffmpeg(test_case):
     """
```
```diff
@@ -653,3 +653,17 @@ class TextGenerationPipelineTests(unittest.TestCase):
         with CaptureLogger(logger) as cl:
             _ = text_generator(prompt, max_length=10)
         self.assertNotIn(logger_msg, cl.out)
+
+    @require_torch
+    def test_pipeline_assisted_generation(self):
+        """Tests that we can run assisted generation in the pipeline"""
+        model = "hf-internal-testing/tiny-random-MistralForCausalLM"
+        pipe = pipeline("text-generation", model=model, assistant_model=model)
+
+        # We can run the pipeline
+        prompt = "Hello world"
+        _ = pipe(prompt)
+
+        # It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
+        with self.assertRaises(ValueError):
+            _ = pipe(prompt, generate_kwargs={"num_beams": 2})
```