Implement AsyncTextIteratorStreamer for asynchronous streaming (#34931)
* Add AsyncTextIteratorStreamer class
* export AsyncTextIteratorStreamer
* export AsyncTextIteratorStreamer
* improve docs
* missing import
* missing import
* doc example fix
* doc example output fix
* add pytest-asyncio
* first attempt at tests
* missing import
* add pytest-asyncio
* fallback to wait_for and raise TimeoutError on timeout
* check for TimeoutError
* autodoc
* reorder imports
* fix style

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent: b5a557e5fe
commit: eafbb0eca7
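The heart of the change, per the bullets above, is a thread-to-event-loop handoff: `model.generate()` keeps running in a worker thread while decoded text is pushed onto an `asyncio.Queue` owned by the consumer's event loop. A minimal, stdlib-only sketch of that pattern (hypothetical names; not code from this commit):

```python
import asyncio
import threading
import time


async def main():
    loop = asyncio.get_running_loop()
    queue: "asyncio.Queue" = asyncio.Queue()
    stop = object()  # sentinel marking the end of the stream

    def produce():
        # Stand-in for model.generate() emitting text from a worker thread.
        for chunk in ["one, ", "two, ", "three"]:
            time.sleep(0.01)
            # asyncio.Queue is not thread-safe, so hand items to the loop's
            # thread instead of calling queue.put_nowait() directly.
            loop.call_soon_threadsafe(queue.put_nowait, chunk)
        loop.call_soon_threadsafe(queue.put_nowait, stop)

    threading.Thread(target=produce).start()
    while (value := await queue.get()) is not stop:
        print(value, end="", flush=True)


asyncio.run(main())
```

This is also why the new class must be constructed inside a coroutine: it needs `asyncio.get_running_loop()` to know which loop to hand text to.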
docs/source/en/internal/generation_utils.md
@@ -352,6 +352,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] TextIteratorStreamer
 
+[[autodoc]] AsyncTextIteratorStreamer
+
 ## Caches
 
 [[autodoc]] Cache
setup.py
@@ -148,6 +148,7 @@ _deps = [
     "pyyaml>=5.1",
     "pydantic",
     "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio",
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.9.0",
@@ -319,6 +320,7 @@ extras["tiktoken"] = deps_list("tiktoken", "blobfile")
 extras["testing"] = (
     deps_list(
         "pytest",
+        "pytest-asyncio",
         "pytest-rich",
         "pytest-xdist",
         "timeout-decorator",
src/transformers/__init__.py
@@ -122,6 +122,7 @@ _import_structure = {
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
     "generation": [
+        "AsyncTextIteratorStreamer",
         "CompileConfig",
         "GenerationConfig",
         "TextIteratorStreamer",
@@ -5055,7 +5056,14 @@ if TYPE_CHECKING:
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 
     # Generation
-    from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
+    from .generation import (
+        AsyncTextIteratorStreamer,
+        CompileConfig,
+        GenerationConfig,
+        TextIteratorStreamer,
+        TextStreamer,
+        WatermarkingConfig,
+    )
     from .hf_argparser import HfArgumentParser
 
     # Integrations
src/transformers/dependency_versions_table.py
@@ -54,6 +54,7 @@ deps = {
     "pyyaml": "pyyaml>=5.1",
     "pydantic": "pydantic",
     "pytest": "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio": "pytest-asyncio",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.9.0",
src/transformers/generation/__init__.py
@@ -26,7 +26,7 @@ _import_structure = {
         "SynthIDTextWatermarkingConfig",
         "WatermarkingConfig",
     ],
-    "streamers": ["TextIteratorStreamer", "TextStreamer"],
+    "streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"],
 }
 
 try:
@@ -199,7 +199,7 @@ if TYPE_CHECKING:
         SynthIDTextWatermarkingConfig,
         WatermarkingConfig,
     )
-    from .streamers import TextIteratorStreamer, TextStreamer
+    from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer
 
     try:
         if not is_torch_available():
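Each new symbol shows up twice in these `__init__` diffs because transformers loads lazily: names are registered in `_import_structure` for runtime resolution and imported again under `TYPE_CHECKING` for static analyzers. A condensed sketch of the pattern (abridged and illustrative, not the actual file contents):

```python
import sys
from typing import TYPE_CHECKING

_import_structure = {"streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"]}

if TYPE_CHECKING:
    # Seen only by type checkers; no import cost at runtime.
    from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer
else:
    from ..utils import _LazyModule

    # Attributes resolve on first access instead of at import time.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```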
src/transformers/generation/streamers.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 from queue import Queue
 from typing import TYPE_CHECKING, Optional
 
@@ -225,3 +226,91 @@ class TextIteratorStreamer(TextStreamer):
             raise StopIteration()
         else:
             return value
+
+
+class AsyncTextIteratorStreamer(TextStreamer):
+    """
+    Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
+    This is useful for applications that benefit from accessing the generated text asynchronously (e.g. in an
+    interactive Gradio demo).
+
+    <Tip warning={true}>
+
+    The API for the streamer classes is still under development and may change in the future.
+
+    </Tip>
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        timeout (`float`, *optional*):
+            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+            in `.generate()`, when it is called in a separate thread.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.
+
+    Raises:
+        TimeoutError: If token generation time exceeds timeout value.
+
+    Examples:
+
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
+    >>> from threading import Thread
+    >>> import asyncio
+
+    >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
+    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+    >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+
+    >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
+    >>> async def main():
+    ...     # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
+    ...     streamer = AsyncTextIteratorStreamer(tok)
+    ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
+    ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    ...     thread.start()
+    ...     generated_text = ""
+    ...     async for new_text in streamer:
+    ...         generated_text += new_text
+    ...     print(generated_text)
+    >>> asyncio.run(main())
+    An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+    ```
+    """
+
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.text_queue = asyncio.Queue()
+        self.stop_signal = None
+        self.timeout = timeout
+        self.loop = asyncio.get_running_loop()
+        self.has_asyncio_timeout = hasattr(asyncio, "timeout")
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
+        if stream_end:
+            self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            if self.has_asyncio_timeout:
+                async with asyncio.timeout(self.timeout):
+                    value = await self.text_queue.get()
+            else:
+                value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
+        except asyncio.TimeoutError:
+            raise TimeoutError()
+        else:
+            if value == self.stop_signal:
+                raise StopAsyncIteration()
+            else:
+                return value
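The `__anext__` above prefers `asyncio.timeout` (a context manager added in Python 3.11) and falls back to `asyncio.wait_for` on older interpreters, converting `asyncio.TimeoutError` into the builtin `TimeoutError`. A standalone sketch of the same fallback (hypothetical helper, not part of the commit):

```python
import asyncio
from typing import Optional


async def get_with_timeout(queue: "asyncio.Queue[str]", timeout: Optional[float]):
    # Same branching as AsyncTextIteratorStreamer.__anext__: use the
    # asyncio.timeout context manager where it exists (Python 3.11+),
    # otherwise fall back to asyncio.wait_for.
    try:
        if hasattr(asyncio, "timeout"):
            async with asyncio.timeout(timeout):
                return await queue.get()
        return await asyncio.wait_for(queue.get(), timeout=timeout)
    except asyncio.TimeoutError:
        raise TimeoutError("no text arrived within the timeout")


async def demo():
    empty_queue: "asyncio.Queue[str]" = asyncio.Queue()
    try:
        await get_with_timeout(empty_queue, timeout=0.01)  # empty queue -> times out
    except TimeoutError as exc:
        print("caught:", exc)


asyncio.run(demo())
```

Note that the streamer re-raises the builtin `TimeoutError`, which is what the new timeout test below asserts against.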
tests/generation/test_streamers.py
@@ -17,7 +17,15 @@ import unittest
 from queue import Empty
 from threading import Thread
 
-from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
+import pytest
+
+from transformers import (
+    AsyncTextIteratorStreamer,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    TextStreamer,
+    is_torch_available,
+)
 from transformers.testing_utils import CaptureStdout, require_torch, torch_device
 
 from ..test_modeling_common import ids_tensor
@@ -120,3 +128,43 @@ class StreamerTester(unittest.TestCase):
             streamer_text = ""
             for new_text in streamer:
                 streamer_text += new_text
+
+
+@require_torch
+@pytest.mark.asyncio(loop_scope="class")
+class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
+    async def test_async_iterator_streamer_matches_non_streaming(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
+        greedy_text = tokenizer.decode(greedy_ids[0])
+
+        streamer = AsyncTextIteratorStreamer(tokenizer)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        streamer_text = ""
+        async for new_text in streamer:
+            streamer_text += new_text
+
+        self.assertEqual(streamer_text, greedy_text)
+
+    async def test_async_iterator_streamer_timeout(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
+        with self.assertRaises(TimeoutError):
+            streamer_text = ""
+            async for new_text in streamer:
+                streamer_text += new_text
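One detail worth noting in the test class above: `AsyncTextIteratorStreamer` captures the running event loop in `__init__`, so the streamer must be created and iterated on the same loop, and `loop_scope="class"` (a pytest-asyncio option) keeps a single loop for the whole class. Assuming a standard dev install with the `testing` extra, an invocation along the lines of `python -m pytest tests/generation/test_streamers.py -k async` (illustrative, not from the commit) should select the new tests.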