Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
parent f0aeb1be17
commit 228792a9dc
@@ -139,6 +139,29 @@ one for summarization with beam search). You must have the right Hub permissions
['Les fichiers de configuration sont faciles à utiliser !']
```

## Streaming

`generate()` supports streaming through its `streamer` input. The `streamer` input is compatible with any instance
of a class that implements the following methods: `put()` and `end()`. Internally, `put()` is used to push new tokens
and `end()` is used to flag the end of text generation.

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` to
your screen, one word at a time:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)

>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```
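
To illustrate the `put()`/`end()` protocol described above, here is a minimal sketch of a custom streamer. The `TokenCollector` class is hypothetical (not part of the library) and simply records the streamed token ids instead of printing them; it reuses `model` and `inputs` from the example above:

```python
class TokenCollector:
    """Illustrative streamer: stores the streamed token ids instead of printing them."""

    def __init__(self):
        self.token_ids = []

    def put(self, value):
        # `value` is a tensor of token ids: the prompt on the first call, then one new token per step
        self.token_ids.extend(value.view(-1).tolist())

    def end(self):
        print(f"Generation finished after streaming {len(self.token_ids)} tokens.")


collector = TokenCollector()
_ = model.generate(**inputs, streamer=collector, max_new_tokens=20)
```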

## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
@@ -265,3 +265,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

## Streamers

[[autodoc]] TextStreamer
@@ -24,7 +24,8 @@ of the generation method.

To learn how to inspect a model's generation configuration, what are the defaults, how to change the parameters ad hoc,
and how to create and save a customized generation configuration, refer to the
[text generation strategies guide](../generation_strategies).
[text generation strategies guide](../generation_strategies). The guide also explains how to use related features,
like token streaming.

## GenerationConfig
@@ -96,7 +96,7 @@ _import_structure = {
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
|
||||
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
|
||||
"file_utils": [],
|
||||
"generation": ["GenerationConfig"],
|
||||
"generation": ["GenerationConfig", "TextStreamer"],
|
||||
"hf_argparser": ["HfArgumentParser"],
|
||||
"image_transforms": [],
|
||||
"integrations": [
|
||||
@@ -3769,7 +3769,7 @@ if TYPE_CHECKING:
    from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

    # Generation
    from .generation import GenerationConfig
    from .generation import GenerationConfig, TextStreamer
    from .hf_argparser import HfArgumentParser

    # Integrations
@@ -17,8 +17,7 @@ from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {"configuration_utils": ["GenerationConfig"]}

_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}

try:
    if not is_torch_available():
@@ -150,6 +149,7 @@ else:

if TYPE_CHECKING:
    from .configuration_utils import GenerationConfig
    from .streamers import TextStreamer

    try:
        if not is_torch_available():
src/transformers/generation/streamers.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..models.auto import AutoTokenizer


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer"):
        self.tokenizer = tokenizer
        self.token_cache = []
        self.print_len = 0

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # Otherwise, print up to the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        print(printable_text, flush=True, end="")

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # Print a newline (and the remaining text, if any)
        print(printable_text, flush=True)
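
As a quick illustrative sketch (not from the patch itself) of the word-boundary buffering that `put()` implements above, the snippet below feeds token ids to a `TextStreamer` one at a time, assuming the same `gpt2` tokenizer used in the docstring example. A word is only printed once the decoded text shows it is complete, and `end()` flushes whatever remains:

```python
from transformers import AutoTokenizer, TextStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
streamer = TextStreamer(tok)

# Tokenize a short sentence and push the ids one by one, as generate() would.
ids = tok("An increasing sequence: one, two", return_tensors="pt").input_ids[0]
for token_id in ids:
    streamer.put(token_id.unsqueeze(0))  # prints completed words only
streamer.end()  # flushes the trailing partial word and prints a newline
```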
@@ -18,7 +18,7 @@ import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -72,6 +72,10 @@ from .stopping_criteria import (
)


if TYPE_CHECKING:
    from .streamers import BaseStreamer


logger = logging.get_logger(__name__)
@@ -1116,6 +1120,7 @@ class GenerationMixin:
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""
@@ -1165,6 +1170,9 @@
                Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
                `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
                generating before other GPUs. Otherwise it'll be set to `False`.
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.

            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
@@ -1295,6 +1303,9 @@
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
@@ -1335,7 +1346,8 @@
        )

        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            (generation_config.num_beams == 1)
            and generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
@@ -1384,6 +1396,11 @@
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        if streamer is not None and (generation_config.num_beams > 1):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        if self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
@@ -1426,6 +1443,7 @@
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )
@@ -1447,6 +1465,7 @@
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )
@@ -1473,6 +1492,7 @@
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )
@@ -1703,6 +1723,7 @@
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
        r"""
@@ -1750,6 +1771,9 @@
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2010,6 +2034,8 @@

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
@@ -2027,6 +2053,9 @@
                else:
                    this_peer_finished = True

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return ContrastiveSearchEncoderDecoderOutput(
@@ -2061,6 +2090,7 @@
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
@@ -2105,6 +2135,9 @@
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2256,6 +2289,8 @@

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
@@ -2273,6 +2308,9 @@
                else:
                    this_peer_finished = True

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
@@ -2308,6 +2346,7 @@
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
@@ -2354,6 +2393,9 @@
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2525,6 +2567,8 @@

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
@@ -2542,6 +2586,9 @@
                else:
                    this_peer_finished = True

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
tests/generation/test_streamers.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import AutoTokenizer, TextStreamer, is_torch_available
from transformers.testing_utils import CaptureStdout, require_torch, torch_device

from ..test_modeling_common import ids_tensor


if is_torch_available():
    from transformers import AutoModelForCausalLM


@require_torch
class StreamerTester(unittest.TestCase):
    def test_text_streamer_stdout(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        greedy_text = tokenizer.decode(greedy_ids[0])

        with CaptureStdout() as cs:
            streamer = TextStreamer(tokenizer)
            model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)

        # The greedy text should be printed to stdout, except for the final "\n" in the streamer
        self.assertEqual(cs.out[:-1], greedy_text)