diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index a54ac432006..d8931342ee4 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -352,6 +352,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] TextIteratorStreamer
 
+[[autodoc]] AsyncTextIteratorStreamer
+
 ## Caches
 
 [[autodoc]] Cache
diff --git a/setup.py b/setup.py
index c2c0048d691..9e678db9978 100644
--- a/setup.py
+++ b/setup.py
@@ -148,6 +148,7 @@ _deps = [
     "pyyaml>=5.1",
     "pydantic",
     "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio",
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.9.0",
@@ -319,6 +320,7 @@ extras["tiktoken"] = deps_list("tiktoken", "blobfile")
 extras["testing"] = (
     deps_list(
         "pytest",
+        "pytest-asyncio",
         "pytest-rich",
         "pytest-xdist",
         "timeout-decorator",
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 681bf1a5d16..5510ac6c8ad 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -122,6 +122,7 @@ _import_structure = {
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
     "generation": [
+        "AsyncTextIteratorStreamer",
         "CompileConfig",
         "GenerationConfig",
         "TextIteratorStreamer",
@@ -5055,7 +5056,14 @@ if TYPE_CHECKING:
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 
     # Generation
-    from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
+    from .generation import (
+        AsyncTextIteratorStreamer,
+        CompileConfig,
+        GenerationConfig,
+        TextIteratorStreamer,
+        TextStreamer,
+        WatermarkingConfig,
+    )
     from .hf_argparser import HfArgumentParser
 
     # Integrations
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 85345cc8e58..c370c7a5d7c 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -54,6 +54,7 @@ deps = {
     "pyyaml": "pyyaml>=5.1",
     "pydantic": "pydantic",
     "pytest": "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio": "pytest-asyncio",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.9.0",
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index 59d970db154..d3eb10c1e6b 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -26,7 +26,7 @@ _import_structure = {
         "SynthIDTextWatermarkingConfig",
         "WatermarkingConfig",
     ],
-    "streamers": ["TextIteratorStreamer", "TextStreamer"],
+    "streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"],
 }
 
 try:
@@ -199,7 +199,7 @@ if TYPE_CHECKING:
         SynthIDTextWatermarkingConfig,
         WatermarkingConfig,
     )
-    from .streamers import TextIteratorStreamer, TextStreamer
+    from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer
 
 try:
     if not is_torch_available():
diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py
index c75b43466af..c78e259db38 100644
--- a/src/transformers/generation/streamers.py
+++ b/src/transformers/generation/streamers.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 from queue import Queue
 from typing import TYPE_CHECKING, Optional
 
@@ -225,3 +226,91 @@ class TextIteratorStreamer(TextStreamer):
             raise StopIteration()
         else:
             return value
+
+
+class AsyncTextIteratorStreamer(TextStreamer):
+    """
+    Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
+    This is useful for applications that benefit from accessing the generated text asynchronously (e.g. in an
+    interactive Gradio demo).
+
+    <Tip warning={true}>
+
+    The API for the streamer classes is still under development and may change in the future.
+
+    </Tip>
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        timeout (`float`, *optional*):
+            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+            in `.generate()`, when it is called in a separate thread.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.
+
+    Raises:
+        TimeoutError: If token generation time exceeds timeout value.
+
+    Examples:
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
+        >>> from threading import Thread
+        >>> import asyncio
+
+        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+
+        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
+        >>> async def main():
+        ...     # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
+        ...     streamer = AsyncTextIteratorStreamer(tok)
+        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
+        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        ...     thread.start()
+        ...     generated_text = ""
+        ...     async for new_text in streamer:
+        ...         generated_text += new_text
+        ...     print(generated_text)
+        >>> asyncio.run(main())
+        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+        ```
+    """
+
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.text_queue = asyncio.Queue()
+        self.stop_signal = None
+        self.timeout = timeout
+        self.loop = asyncio.get_running_loop()
+        self.has_asyncio_timeout = hasattr(asyncio, "timeout")
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Put the new text in the queue.
+        If the stream is ending, also put a stop signal in the queue."""
+        self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
+        if stream_end:
+            self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            if self.has_asyncio_timeout:
+                async with asyncio.timeout(self.timeout):
+                    value = await self.text_queue.get()
+            else:
+                value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
+        except asyncio.TimeoutError:
+            raise TimeoutError()
+        else:
+            if value == self.stop_signal:
+                raise StopAsyncIteration()
+            else:
+                return value
diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py
index c82a5e99e0d..be8c37334d0 100644
--- a/tests/generation/test_streamers.py
+++ b/tests/generation/test_streamers.py
@@ -17,7 +17,15 @@ import unittest
 from queue import Empty
 from threading import Thread
 
-from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
+import pytest
+
+from transformers import (
+    AsyncTextIteratorStreamer,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    TextStreamer,
+    is_torch_available,
+)
 from transformers.testing_utils import CaptureStdout, require_torch, torch_device
 
 from ..test_modeling_common import ids_tensor
@@ -120,3 +128,43 @@ class StreamerTester(unittest.TestCase):
             streamer_text = ""
             for new_text in streamer:
                 streamer_text += new_text
+
+
+@require_torch
+@pytest.mark.asyncio(loop_scope="class")
+class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
+    async def test_async_iterator_streamer_matches_non_streaming(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
+        greedy_text = tokenizer.decode(greedy_ids[0])
+
+        streamer = AsyncTextIteratorStreamer(tokenizer)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        streamer_text = ""
+        async for new_text in streamer:
+            streamer_text += new_text
+
+        self.assertEqual(streamer_text, greedy_text)
+
+    async def test_async_iterator_streamer_timeout(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
+        with self.assertRaises(TimeoutError):
+            streamer_text = ""
+            async for new_text in streamer:
+                streamer_text += new_text
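
For anyone trying the branch locally, here is a minimal end-to-end sketch of how the new class is meant to be consumed. It is illustrative only: the tiny test checkpoint and the `stream_completion` helper are placeholder choices, not part of the diff.

```python
# Sketch: stream decoded text in a coroutine while generate() runs in a worker thread.
import asyncio
from threading import Thread

from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer


async def stream_completion(prompt: str) -> str:
    tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    inputs = tok([prompt], return_tensors="pt")

    # Must be constructed inside a running event loop: __init__ calls asyncio.get_running_loop().
    streamer = AsyncTextIteratorStreamer(tok, skip_prompt=True, timeout=30.0)

    # generate() blocks, so it runs in a thread; on_finalized_text() hands each decoded
    # chunk back to the event loop via loop.call_soon_threadsafe().
    thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=10))
    thread.start()

    chunks = []
    try:
        async for chunk in streamer:  # iteration ends when generate() sends the stop signal
            chunks.append(chunk)
    except TimeoutError:
        # Raised if no chunk arrives within `timeout`, e.g. because generate() errored in
        # the worker thread; with timeout=None the iterator would instead wait forever.
        pass
    finally:
        thread.join()
    return "".join(chunks)


print(asyncio.run(stream_completion("An increasing sequence: one,")))
```

Setting a finite `timeout` exercises both branches of `__anext__`: `asyncio.timeout` on Python 3.11+ and the `asyncio.wait_for` fallback on older interpreters.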