Implement AsyncTextIteratorStreamer for asynchronous streaming (#34931)

* Add AsyncTextIteratorStreamer class

* export AsyncTextIteratorStreamer

* export AsyncTextIteratorStreamer

* improve docs

* missing import

* missing import

* doc example fix

* doc example output fix

* add pytest-asyncio

* first attempt at tests

* missing import

* add pytest-asyncio

* fallback to wait_for and raise TimeoutError on timeout

* check for TimeoutError

* autodoc

* reorder imports

* fix style

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Sigbjørn Skjæret 2024-12-20 12:08:12 +01:00 committed by GitHub
parent b5a557e5fe
commit eafbb0eca7
7 changed files with 154 additions and 4 deletions
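
At a glance, the new streamer is created inside a running event loop and consumed with `async for`, while the blocking `generate()` call runs in a separate thread. A minimal usage sketch, condensed from the docstring example added to the streamers module below (model, tokenizer, and prompt are the ones used in that example):

```python
import asyncio
from threading import Thread

from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")


async def main():
    # The streamer must be created inside a coroutine so it can pick up the running event loop.
    streamer = AsyncTextIteratorStreamer(tok)
    # generate() blocks, so it runs in a worker thread while this coroutine awaits text chunks.
    thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=20))
    thread.start()
    generated_text = ""
    async for new_text in streamer:
        generated_text += new_text
    print(generated_text)


asyncio.run(main())
```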


@@ -352,6 +352,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 [[autodoc]] TextIteratorStreamer

+[[autodoc]] AsyncTextIteratorStreamer
+
 ## Caches

 [[autodoc]] Cache


@@ -148,6 +148,7 @@ _deps = [
     "pyyaml>=5.1",
     "pydantic",
     "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio",
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.9.0",

@@ -319,6 +320,7 @@ extras["tiktoken"] = deps_list("tiktoken", "blobfile")
 extras["testing"] = (
     deps_list(
         "pytest",
+        "pytest-asyncio",
         "pytest-rich",
         "pytest-xdist",
         "timeout-decorator",


@@ -122,6 +122,7 @@ _import_structure = {
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
     "generation": [
+        "AsyncTextIteratorStreamer",
         "CompileConfig",
         "GenerationConfig",
         "TextIteratorStreamer",

@@ -5055,7 +5056,14 @@ if TYPE_CHECKING:
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

     # Generation
-    from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
+    from .generation import (
+        AsyncTextIteratorStreamer,
+        CompileConfig,
+        GenerationConfig,
+        TextIteratorStreamer,
+        TextStreamer,
+        WatermarkingConfig,
+    )
     from .hf_argparser import HfArgumentParser

     # Integrations


@@ -54,6 +54,7 @@ deps = {
     "pyyaml": "pyyaml>=5.1",
     "pydantic": "pydantic",
     "pytest": "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio": "pytest-asyncio",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.9.0",


@@ -26,7 +26,7 @@ _import_structure = {
         "SynthIDTextWatermarkingConfig",
         "WatermarkingConfig",
     ],
-    "streamers": ["TextIteratorStreamer", "TextStreamer"],
+    "streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"],
 }

 try:

@@ -199,7 +199,7 @@ if TYPE_CHECKING:
         SynthIDTextWatermarkingConfig,
         WatermarkingConfig,
     )
-    from .streamers import TextIteratorStreamer, TextStreamer
+    from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer

     try:
         if not is_torch_available():


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
 from queue import Queue
 from typing import TYPE_CHECKING, Optional
@@ -225,3 +226,91 @@ class TextIteratorStreamer(TextStreamer):
             raise StopIteration()
         else:
             return value
+
+
+class AsyncTextIteratorStreamer(TextStreamer):
+    """
+    Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
+    This is useful for applications that benefit from accessing the generated text asynchronously (e.g. in an
+    interactive Gradio demo).
+
+    <Tip warning={true}>
+
+    The API for the streamer classes is still under development and may change in the future.
+
+    </Tip>
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        timeout (`float`, *optional*):
+            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+            in `.generate()`, when it is called in a separate thread.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.
+
+    Raises:
+        TimeoutError: If token generation time exceeds timeout value.
+
+    Examples:
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
+        >>> from threading import Thread
+        >>> import asyncio
+
+        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+
+        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
+        >>> async def main():
+        ...     # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
+        ...     streamer = AsyncTextIteratorStreamer(tok)
+        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
+        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        ...     thread.start()
+        ...     generated_text = ""
+        ...     async for new_text in streamer:
+        ...         generated_text += new_text
+        ...     print(generated_text)
+        >>> asyncio.run(main())
+        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+        ```
+    """

+
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: float | None = None, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.text_queue = asyncio.Queue()
+        self.stop_signal = None
+        self.timeout = timeout
+        self.loop = asyncio.get_running_loop()
+        self.has_asyncio_timeout = hasattr(asyncio, "timeout")
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
+        if stream_end:
+            self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            if self.has_asyncio_timeout:
+                async with asyncio.timeout(self.timeout):
+                    value = await self.text_queue.get()
+            else:
+                value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
+        except asyncio.TimeoutError:
+            raise TimeoutError()
+        else:
+            if value == self.stop_signal:
+                raise StopAsyncIteration()
+            else:
+                return value
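
The `__anext__` implementation above uses `asyncio.timeout` when it is available (Python 3.11+) and falls back to `asyncio.wait_for` otherwise, converting either timeout into a plain `TimeoutError` for the consumer. A minimal sketch of how calling code might guard against a stalled or failed `generate()`; the helper name and the 2-second value are illustrative, not part of this commit:

```python
from threading import Thread

from transformers import AsyncTextIteratorStreamer


async def stream_with_timeout(model, tok, inputs) -> str:
    # If no new chunk arrives within 2 seconds, the async for loop raises TimeoutError.
    streamer = AsyncTextIteratorStreamer(tok, timeout=2.0)
    Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=20)).start()
    chunks = []
    try:
        async for new_text in streamer:
            chunks.append(new_text)
    except TimeoutError:
        # Generation stalled or failed in the worker thread; return whatever was produced so far.
        pass
    return "".join(chunks)
```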


@@ -17,7 +17,15 @@ import unittest
 from queue import Empty
 from threading import Thread

-from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
+import pytest
+
+from transformers import (
+    AsyncTextIteratorStreamer,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    TextStreamer,
+    is_torch_available,
+)
 from transformers.testing_utils import CaptureStdout, require_torch, torch_device

 from ..test_modeling_common import ids_tensor

@@ -120,3 +128,43 @@ class StreamerTester(unittest.TestCase):
             streamer_text = ""
             for new_text in streamer:
                 streamer_text += new_text
+
+
+@require_torch
+@pytest.mark.asyncio(loop_scope="class")
+class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
+    async def test_async_iterator_streamer_matches_non_streaming(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
+        greedy_text = tokenizer.decode(greedy_ids[0])
+
+        streamer = AsyncTextIteratorStreamer(tokenizer)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        streamer_text = ""
+        async for new_text in streamer:
+            streamer_text += new_text
+
+        self.assertEqual(streamer_text, greedy_text)
+
+    async def test_async_iterator_streamer_timeout(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
+        with self.assertRaises(TimeoutError):
+            streamer_text = ""
+            async for new_text in streamer:
+                streamer_text += new_text