Mirror of https://github.com/huggingface/transformers.git
Implement AsyncTextIteratorStreamer for asynchronous streaming (#34931)
* Add AsyncTextIteratorStreamer class
* export AsyncTextIteratorStreamer
* export AsyncTextIteratorStreamer
* improve docs
* missing import
* missing import
* doc example fix
* doc example output fix
* add pytest-asyncio
* first attempt at tests
* missing import
* add pytest-asyncio
* fallback to wait_for and raise TimeoutError on timeout
* check for TimeoutError
* autodoc
* reorder imports
* fix style

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Commit: eafbb0eca7 (parent: b5a557e5fe)
@@ -352,6 +352,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] TextIteratorStreamer
 
+[[autodoc]] AsyncTextIteratorStreamer
+
 ## Caches
 
 [[autodoc]] Cache
setup.py (+2)
@@ -148,6 +148,7 @@ _deps = [
     "pyyaml>=5.1",
     "pydantic",
     "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio",
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.9.0",
@@ -319,6 +320,7 @@ extras["tiktoken"] = deps_list("tiktoken", "blobfile")
 extras["testing"] = (
     deps_list(
         "pytest",
+        "pytest-asyncio",
         "pytest-rich",
         "pytest-xdist",
         "timeout-decorator",
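The new `pytest-asyncio` dependency is what lets the test suite declare native `async def` tests. As a minimal sketch (not part of this commit), a coroutine test marked with `pytest.mark.asyncio` is run on an event loop by the plugin instead of being rejected as a bare, un-awaited coroutine:

```python
# Minimal sketch (not from this commit) of what pytest-asyncio enables:
# the marked `async def` test below is executed on an event loop.
import asyncio

import pytest


@pytest.mark.asyncio
async def test_async_queue_roundtrip():
    queue = asyncio.Queue()
    queue.put_nowait("token")
    assert await queue.get() == "token"
```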
@@ -122,6 +122,7 @@ _import_structure = {
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
     "generation": [
+        "AsyncTextIteratorStreamer",
     "CompileConfig",
@@ -5055,7 +5056,14 @@ if TYPE_CHECKING:
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 
     # Generation
-    from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
+    from .generation import (
+        AsyncTextIteratorStreamer,
+        CompileConfig,
+        GenerationConfig,
+        TextIteratorStreamer,
+        TextStreamer,
+        WatermarkingConfig,
+    )
     from .hf_argparser import HfArgumentParser
 
     # Integrations
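Both edits above follow the library's lazy-import convention: the `_import_structure` dict tells the lazy module loader where each public name lives, while the `TYPE_CHECKING` branch gives static type checkers and IDEs a real import without paying the runtime import cost. A simplified sketch of the pattern (the actual implementation is `transformers.utils._LazyModule`; this PEP 562 version is only meant to show why the new name must be registered in both places):

```python
# Simplified sketch of the lazy-import pattern; not the real _LazyModule.
import importlib
from typing import TYPE_CHECKING

_import_structure = {"generation": ["AsyncTextIteratorStreamer"]}

if TYPE_CHECKING:
    # Static analyzers see a plain import...
    from .generation import AsyncTextIteratorStreamer
else:
    # ...while at runtime the submodule is imported on first attribute access.
    def __getattr__(name):
        for submodule, names in _import_structure.items():
            if name in names:
                module = importlib.import_module(f".{submodule}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```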
@@ -54,6 +54,7 @@ deps = {
     "pyyaml": "pyyaml>=5.1",
     "pydantic": "pydantic",
     "pytest": "pytest>=7.2.0,<8.0.0",
+    "pytest-asyncio": "pytest-asyncio",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.9.0",
@@ -26,7 +26,7 @@ _import_structure = {
         "SynthIDTextWatermarkingConfig",
         "WatermarkingConfig",
     ],
-    "streamers": ["TextIteratorStreamer", "TextStreamer"],
+    "streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"],
 }
 
 try:
@@ -199,7 +199,7 @@ if TYPE_CHECKING:
         SynthIDTextWatermarkingConfig,
         WatermarkingConfig,
     )
-    from .streamers import TextIteratorStreamer, TextStreamer
+    from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer
 
     try:
         if not is_torch_available():
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 from queue import Queue
 from typing import TYPE_CHECKING, Optional
 
@@ -225,3 +226,91 @@ class TextIteratorStreamer(TextStreamer):
             raise StopIteration()
         else:
             return value
+
+
+class AsyncTextIteratorStreamer(TextStreamer):
+    """
+    Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
+    This is useful for applications that benefit from accessing the generated text asynchronously (e.g. in an
+    interactive Gradio demo).
+
+    <Tip warning={true}>
+
+    The API for the streamer classes is still under development and may change in the future.
+
+    </Tip>
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        timeout (`float`, *optional*):
+            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+            in `.generate()`, when it is called in a separate thread.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.
+
+    Raises:
+        TimeoutError: If token generation time exceeds timeout value.
+
+    Examples:
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
+        >>> from threading import Thread
+        >>> import asyncio
+
+        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+
+        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
+        >>> async def main():
+        ...     # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
+        ...     streamer = AsyncTextIteratorStreamer(tok)
+        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
+        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        ...     thread.start()
+        ...     generated_text = ""
+        ...     async for new_text in streamer:
+        ...         generated_text += new_text
+        ...     print(generated_text)
+        >>> asyncio.run(main())
+        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+        ```
+    """
+
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: float | None = None, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.text_queue = asyncio.Queue()
+        self.stop_signal = None
+        self.timeout = timeout
+        self.loop = asyncio.get_running_loop()
+        self.has_asyncio_timeout = hasattr(asyncio, "timeout")
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
+        if stream_end:
+            self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            if self.has_asyncio_timeout:
+                async with asyncio.timeout(self.timeout):
+                    value = await self.text_queue.get()
+            else:
+                value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
+        except asyncio.TimeoutError:
+            raise TimeoutError()
+        else:
+            if value == self.stop_signal:
+                raise StopAsyncIteration()
+            else:
+                return value
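Two details of the new class are worth spelling out. First, `__init__` calls `asyncio.get_running_loop()`, which is why the docstring insists the streamer be created inside a coroutine: the captured loop is later used by `on_finalized_text` to hand tokens from the `generate()` worker thread over to the event loop, because `asyncio.Queue` is not thread-safe and may only be touched from its own loop. Second, `asyncio.timeout` only exists on Python 3.11+, hence the `hasattr` check and the `asyncio.wait_for` fallback; both paths normalize `asyncio.TimeoutError` to the builtin `TimeoutError`. A self-contained sketch (not from this commit) of the same thread-to-loop handoff:

```python
# Self-contained sketch (not from this commit) of the handoff pattern
# AsyncTextIteratorStreamer relies on: the producer thread never calls
# queue.put_nowait directly; it schedules the call on the event loop.
import asyncio
from threading import Thread


async def main():
    loop = asyncio.get_running_loop()  # captured on the loop, as __init__ does
    queue = asyncio.Queue()
    stop = object()  # sentinel, playing the role of stop_signal

    def producer():
        # Runs in a plain thread, like model.generate() does.
        for chunk in ("one, ", "two, ", "three"):
            loop.call_soon_threadsafe(queue.put_nowait, chunk)
        loop.call_soon_threadsafe(queue.put_nowait, stop)

    Thread(target=producer).start()
    while (item := await queue.get()) is not stop:
        print(item, end="", flush=True)


asyncio.run(main())
```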
@@ -17,7 +17,15 @@ import unittest
 from queue import Empty
 from threading import Thread
 
-from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
+import pytest
+
+from transformers import (
+    AsyncTextIteratorStreamer,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    TextStreamer,
+    is_torch_available,
+)
 from transformers.testing_utils import CaptureStdout, require_torch, torch_device
 
 from ..test_modeling_common import ids_tensor
@@ -120,3 +128,43 @@ class StreamerTester(unittest.TestCase):
             streamer_text = ""
             for new_text in streamer:
                 streamer_text += new_text
+
+
+@require_torch
+@pytest.mark.asyncio(loop_scope="class")
+class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
+    async def test_async_iterator_streamer_matches_non_streaming(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
+        greedy_text = tokenizer.decode(greedy_ids[0])
+
+        streamer = AsyncTextIteratorStreamer(tokenizer)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        streamer_text = ""
+        async for new_text in streamer:
+            streamer_text += new_text
+
+        self.assertEqual(streamer_text, greedy_text)
+
+    async def test_async_iterator_streamer_timeout(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
+        with self.assertRaises(TimeoutError):
+            streamer_text = ""
+            async for new_text in streamer:
+                streamer_text += new_text
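The second test exercises the timeout path: with `timeout=0.001` no token can arrive in time, so `__anext__` converts the internal `asyncio.TimeoutError` into the builtin `TimeoutError` that the test asserts. A tiny standalone illustration (hypothetical, not part of this commit) of that conversion:

```python
# Hypothetical standalone illustration (not from this commit): a queue that is
# never fed times out, and asyncio.TimeoutError is re-raised as TimeoutError,
# mirroring what the streamer's __anext__ does.
import asyncio


async def consume_with_timeout():
    queue = asyncio.Queue()  # never fed, mimicking a stalled generate()
    try:
        return await asyncio.wait_for(queue.get(), timeout=0.001)
    except asyncio.TimeoutError:
        raise TimeoutError() from None


try:
    asyncio.run(consume_with_timeout())
except TimeoutError:
    print("timed out, exactly as the streamer surfaces it")
```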