From 228792a9dc0c36f1e82ab441e1b1991d116ee0a0 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 30 Mar 2023 12:00:12 +0100
Subject: [PATCH] Generate: basic token streaming (#22449)

* haha tokens go brrrr
---
 docs/source/en/generation_strategies.mdx      |  23 ++++
 docs/source/en/internal/generation_utils.mdx  |   4 +
 .../en/main_classes/text_generation.mdx       |   3 +-
 src/transformers/__init__.py                  |   4 +-
 src/transformers/generation/__init__.py       |   4 +-
 src/transformers/generation/streamers.py      | 104 ++++++++++++++++++
 src/transformers/generation/utils.py          |  51 ++++++++-
 tests/generation/test_streamers.py            |  44 ++++++++
 8 files changed, 230 insertions(+), 7 deletions(-)
 create mode 100644 src/transformers/generation/streamers.py
 create mode 100644 tests/generation/test_streamers.py

diff --git a/docs/source/en/generation_strategies.mdx b/docs/source/en/generation_strategies.mdx
index 44692169405..831c8772b6c 100644
--- a/docs/source/en/generation_strategies.mdx
+++ b/docs/source/en/generation_strategies.mdx
@@ -139,6 +139,29 @@ one for summarization with beam search). You must have the right Hub permissions
 ['Les fichiers de configuration sont faciles à utiliser !']
 ```
 
+## Streaming
+
+The `generate()` method supports streaming through its `streamer` input. The `streamer` input is compatible with any
+instance of a class that has the following methods: `put()` and `end()`. Internally, `put()` is used to push new
+tokens and `end()` is used to flag the end of text generation.
+
+In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
+ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` into
+your screen, one word at a time:
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+
+>>> tok = AutoTokenizer.from_pretrained("gpt2")
+>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+>>> streamer = TextStreamer(tok)
+
+>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
+>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
+An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+```
+
 ## Decoding strategies
 
 Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx
index 3c86b7dc3f0..dd93c79e92b 100644
--- a/docs/source/en/internal/generation_utils.mdx
+++ b/docs/source/en/internal/generation_utils.mdx
@@ -265,3 +265,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 [[autodoc]] top_k_top_p_filtering
 
 [[autodoc]] tf_top_k_top_p_filtering
+
+## Streamers
+
+[[autodoc]] TextStreamer
diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx
index 5351129cbb1..39a15160346 100644
--- a/docs/source/en/main_classes/text_generation.mdx
+++ b/docs/source/en/main_classes/text_generation.mdx
@@ -24,7 +24,8 @@ of the generation method.
 
 To learn how to inspect a model's generation configuration, what are the defaults, how to change the parameters ad hoc,
 and how to create and save a customized generation configuration, refer to the
-[text generation strategies guide](../generation_strategies).
+[text generation strategies guide](../generation_strategies). The guide also explains how to use related features,
+like token streaming.
 
 ## GenerationConfig
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 39ae4e97fd3..ad1b335ed3a 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -96,7 +96,7 @@ _import_structure = {
     "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
-    "generation": ["GenerationConfig"],
+    "generation": ["GenerationConfig", "TextStreamer"],
     "hf_argparser": ["HfArgumentParser"],
     "image_transforms": [],
     "integrations": [
@@ -3769,7 +3769,7 @@ if TYPE_CHECKING:
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 
     # Generation
-    from .generation import GenerationConfig
+    from .generation import GenerationConfig, TextStreamer
     from .hf_argparser import HfArgumentParser
 
     # Integrations
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index a5f4aa01491..d163c44dc7f 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -17,8 +17,7 @@ from typing import TYPE_CHECKING
 from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
 
 
-_import_structure = {"configuration_utils": ["GenerationConfig"]}
-
+_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}
 
 try:
     if not is_torch_available():
@@ -150,6 +149,7 @@ else:
 
 if TYPE_CHECKING:
     from .configuration_utils import GenerationConfig
+    from .streamers import TextStreamer
 
     try:
         if not is_torch_available():
diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py
new file mode 100644
index 00000000000..d110693b0ea
--- /dev/null
+++ b/src/transformers/generation/streamers.py
@@ -0,0 +1,104 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from ..models.auto import AutoTokenizer
+
+
+class BaseStreamer:
+    """
+    Base class from which `.generate()` streamers should inherit.
+    """
+
+    def put(self, value):
+        """Function that is called by `.generate()` to push new tokens"""
+        raise NotImplementedError()
+
+    def end(self):
+        """Function that is called by `.generate()` to signal the end of generation"""
+        raise NotImplementedError()
+
+
+class TextStreamer(BaseStreamer):
+    """
+    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+
+    Examples:
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+
+        >>> tok = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+        >>> streamer = TextStreamer(tok)
+
+        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
+        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
+        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+        ```
+    """
+
+    def __init__(self, tokenizer: "AutoTokenizer"):
+        self.tokenizer = tokenizer
+        self.token_cache = []
+        self.print_len = 0
+
+    def put(self, value):
+        """
+        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
+        """
+        if len(value.shape) > 1 and value.shape[0] > 1:
+            raise ValueError("TextStreamer only supports batch size 1")
+        elif len(value.shape) > 1:
+            value = value[0]
+
+        # Add the new token to the cache and decode the entire thing.
+        self.token_cache.extend(value.tolist())
+        text = self.tokenizer.decode(self.token_cache)
+
+        # After the symbol for a new line, we flush the cache.
+        if text.endswith("\n"):
+            printable_text = text[self.print_len :]
+            self.token_cache = []
+            self.print_len = 0
+        # Otherwise, print until the last space char (simple heuristic to avoid printing incomplete words,
+        # which may change with the subsequent token -- there are probably smarter ways to do this!)
+        else:
+            printable_text = text[self.print_len : text.rfind(" ") + 1]
+            self.print_len += len(printable_text)
+
+        print(printable_text, flush=True, end="")
+
+    def end(self):
+        """Flushes any remaining cache and prints a newline to stdout."""
+        # Flush the cache, if it exists
+        if len(self.token_cache) > 0:
+            text = self.tokenizer.decode(self.token_cache)
+            printable_text = text[self.print_len :]
+            self.token_cache = []
+            self.print_len = 0
+        else:
+            printable_text = ""
+
+        # Print a newline (and the remaining text, if any)
+        print(printable_text, flush=True)
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ee92e51a19a..ae12ae2930f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -18,7 +18,7 @@ import copy
 import inspect
 import warnings
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -72,6 +72,10 @@ from .stopping_criteria import (
 )
 
 
+if TYPE_CHECKING:
+    from .streamers import BaseStreamer
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -1116,6 +1120,7 @@ class GenerationMixin:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         synced_gpus: Optional[bool] = None,
+        streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -1165,6 +1170,9 @@ class GenerationMixin:
                 Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
                 `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
                 generating before other GPUs. Otherwise it'll be set to `False`.
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
             kwargs:
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
@@ -1295,6 +1303,9 @@ class GenerationMixin:
         else:
             input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
+        if streamer is not None:
+            streamer.put(input_ids.cpu())
+
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
@@ -1335,7 +1346,8 @@ class GenerationMixin:
         )
 
         is_contrastive_search_gen_mode = (
-            generation_config.top_k is not None
+            (generation_config.num_beams == 1)
+            and generation_config.top_k is not None
             and generation_config.top_k > 1
             and generation_config.do_sample is False
             and generation_config.penalty_alpha is not None
@@ -1384,6 +1396,11 @@ class GenerationMixin:
                 "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
             )
 
+        if streamer is not None and (generation_config.num_beams > 1):
+            raise ValueError(
+                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
+            )
+
         if self.device.type != input_ids.device.type:
             warnings.warn(
                 "You are calling .generate() with the `input_ids` being on a device type different"
@@ -1426,6 +1443,7 @@ class GenerationMixin:
                 output_scores=generation_config.output_scores,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
+                streamer=streamer,
                 **model_kwargs,
             )
 
@@ -1447,6 +1465,7 @@ class GenerationMixin:
                 output_scores=generation_config.output_scores,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
+                streamer=streamer,
                 **model_kwargs,
             )
 
@@ -1473,6 +1492,7 @@ class GenerationMixin:
                 output_scores=generation_config.output_scores,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
+                streamer=streamer,
                 **model_kwargs,
             )
 
@@ -1703,6 +1723,7 @@ class GenerationMixin:
         output_scores: Optional[bool] = None,
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
+        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
     ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
@@ -1750,6 +1771,9 @@ class GenerationMixin:
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
             synced_gpus (`bool`, *optional*, defaults to `False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2010,6 +2034,8 @@ class GenerationMixin:
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
@@ -2027,6 +2053,9 @@ class GenerationMixin:
                 else:
                     this_peer_finished = True
 
+        if streamer is not None:
+            streamer.end()
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return ContrastiveSearchEncoderDecoderOutput(
@@ -2061,6 +2090,7 @@ class GenerationMixin:
         output_scores: Optional[bool] = None,
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GreedySearchOutput, torch.LongTensor]:
         r"""
@@ -2105,6 +2135,9 @@ class GenerationMixin:
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
             synced_gpus (`bool`, *optional*, defaults to `False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2256,6 +2289,8 @@ class GenerationMixin:
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
@@ -2273,6 +2308,9 @@ class GenerationMixin:
                 else:
                     this_peer_finished = True
 
+        if streamer is not None:
+            streamer.end()
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return GreedySearchEncoderDecoderOutput(
@@ -2308,6 +2346,7 @@ class GenerationMixin:
         output_scores: Optional[bool] = None,
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[SampleOutput, torch.LongTensor]:
         r"""
@@ -2354,6 +2393,9 @@ class GenerationMixin:
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
             synced_gpus (`bool`, *optional*, defaults to `False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2525,6 +2567,8 @@ class GenerationMixin:
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
@@ -2542,6 +2586,9 @@ class GenerationMixin:
                 else:
                     this_peer_finished = True
 
+        if streamer is not None:
+            streamer.end()
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return SampleEncoderDecoderOutput(
diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py
new file mode 100644
index 00000000000..12062328590
--- /dev/null
+++ b/tests/generation/test_streamers.py
@@ -0,0 +1,44 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import AutoTokenizer, TextStreamer, is_torch_available
+from transformers.testing_utils import CaptureStdout, require_torch, torch_device
+
+from ..test_modeling_common import ids_tensor
+
+
+if is_torch_available():
+    from transformers import AutoModelForCausalLM
+
+
+@require_torch
+class StreamerTester(unittest.TestCase):
+    def test_text_streamer_stdout(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
+        greedy_text = tokenizer.decode(greedy_ids[0])
+
+        with CaptureStdout() as cs:
+            streamer = TextStreamer(tokenizer)
+            model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
+
+        # The greedy text should be printed to stdout, except for the final "\n" in the streamer
+        self.assertEqual(cs.out[:-1], greedy_text)
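
The documentation added in this patch states that `generate()` accepts any object exposing `put()` and `end()`: the prompt ids are pushed first via `put(input_ids)`, then each batch of newly generated token ids via `put(next_tokens)`, and `end()` is called once generation finishes. As a minimal sketch of a custom streamer built on that contract, the class below collects the decoded text in a Python list instead of printing it. The `ListStreamer` name, its attributes, and the skip-the-prompt bookkeeping are illustrative assumptions for this sketch, not APIs introduced by the patch.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer


class ListStreamer:
    """Hypothetical streamer: collects decoded text pieces in a list instead of printing them."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pieces = []  # decoded text fragments, in generation order
        self.prompt_received = False

    def put(self, value):
        # `generate()` first pushes the whole prompt (shape [batch, seq_len]),
        # then one tensor of new token ids per decoding step.
        if not self.prompt_received:
            self.prompt_received = True
            return
        if len(value.shape) > 1:
            value = value[0]
        self.pieces.append(self.tokenizer.decode(value.tolist()))

    def end(self):
        # Nothing is buffered in this sketch, so there is nothing to flush.
        pass


tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")

streamer = ListStreamer(tok)
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
print("".join(streamer.pieces))
```

Decoding each step's tokens independently is only adequate for a sketch: a single token may not decode to a complete word on its own, which is exactly why `TextStreamer` in this patch caches token ids and only flushes text up to the last whole word.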