Update docstrings for text generation pipeline (#30343)

* Update docstrings for text generation pipeline

* Fix docstring arg

* Update docstring to explain chat mode

* Fix doctests

* Fix doctests
Matt 2024-04-22 14:01:30 +01:00 committed by GitHub
parent 2d92db8458
commit 0e9d44d7a1

@@ -37,10 +37,11 @@ class Chat:
class TextGenerationPipeline(Pipeline):
"""
Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
- specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts,
- where each dict contains "role" and "content" keys.
+ specified text prompt. When the underlying model is a conversational model, it can also accept one or more chats,
+ in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
+ Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.
- Example:
+ Examples:
```python
>>> from transformers import pipeline
@@ -53,6 +54,15 @@ class TextGenerationPipeline(Pipeline):
>>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
```
+ ```python
+ >>> from transformers import pipeline
+ >>> generator = pipeline(model="HuggingFaceH4/zephyr-7b-beta")
+ >>> # Zephyr-beta is a conversational model, so let's pass it a chat instead of a single string
+ >>> generator([{"role": "user", "content": "What is the capital of France? Answer in one word."}], do_sample=False, max_new_tokens=2)
+ [{'generated_text': [{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'Paris'}]}]
+ ```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
text generation parameters in [Text generation strategies](../generation_strategies) and [Text
@@ -62,8 +72,9 @@ class TextGenerationPipeline(Pipeline):
`"text-generation"`.
The models that this pipeline can use are models that have been trained with an autoregressive language modeling
- objective, which includes the uni-directional models in the library (e.g. openai-community/gpt2). See the list of available models
- on [huggingface.co/models](https://huggingface.co/models?filter=text-generation).
+ objective. See the list of available [text completion models](https://huggingface.co/models?filter=text-generation)
+ and the list of [conversational models](https://huggingface.co/models?other=conversational)
+ on [huggingface.co/models](https://huggingface.co/models).
"""
# Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
@@ -194,8 +205,11 @@ class TextGenerationPipeline(Pipeline):
Complete the prompt(s) given as inputs.
Args:
- text_inputs (`str` or `List[str]`):
- One or several prompts (or one list of prompts) to complete.
+ text_inputs (`str`, `List[str]`, `List[Dict[str, str]]`, or `List[List[Dict[str, str]]]`):
+ One or several prompts (or one list of prompts) to complete. If strings or a list of strings are
+ passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list
+ of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed,
+ the model's chat template will be used to format them before passing them to the model.
return_tensors (`bool`, *optional*, defaults to `False`):
Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
`True`, the decoded text is not returned.
@@ -222,7 +236,7 @@ class TextGenerationPipeline(Pipeline):
corresponding to your framework [here](./model#generative-models)).
Return:
- A list or a list of list of `dict`: Returns one of the following dictionaries (cannot return a combination
+ A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a combination
of both `generated_text` and `generated_token_ids`):
- **generated_text** (`str`, present when `return_text=True`) -- The generated text.
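
The four `text_inputs` shapes named in the updated docstring (a string, a list of strings, one chat, a list of chats) can be told apart with plain structural checks. The sketch below is for illustration only: `is_chat`, `classify_text_inputs`, and `last_reply` are hypothetical helpers, the returned labels are made up for this example, and none of this is the pipeline's actual dispatch code. The sample `outputs` value is copied from the doctest in the diff, so no model is loaded.

```python
from typing import Any, Dict, List


def is_chat(value: Any) -> bool:
    """Return True if value looks like one chat: a non-empty list of dicts
    that each carry "role" and "content" keys."""
    return (
        isinstance(value, list)
        and len(value) > 0
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in value)
    )


def classify_text_inputs(text_inputs: Any) -> str:
    """Label the input shape (hypothetical labels, not library values)."""
    if isinstance(text_inputs, str):
        return "prompt"
    if is_chat(text_inputs):
        return "chat"
    if isinstance(text_inputs, list) and text_inputs:
        if all(isinstance(item, str) for item in text_inputs):
            return "prompts"
        if all(is_chat(item) for item in text_inputs):
            return "chats"
    raise TypeError("expected a str, a list of str, a chat, or a list of chats")


def last_reply(output_item: Dict[str, Any]) -> str:
    """Pull the final message's content out of one chat-mode output dict."""
    return output_item["generated_text"][-1]["content"]


chat: List[Dict[str, str]] = [
    {"role": "user", "content": "What is the capital of France? Answer in one word."}
]

# Output structure copied from the doctest in the diff above.
outputs = [
    {
        "generated_text": [
            {"role": "user", "content": "What is the capital of France? Answer in one word."},
            {"role": "assistant", "content": "Paris"},
        ]
    }
]

print(classify_text_inputs("My tart needs some"))  # prompt
print(classify_text_inputs(chat))                  # chat
print(classify_text_inputs([chat, chat]))          # chats
print(last_reply(outputs[0]))                      # Paris
```

In chat mode the returned `generated_text` is the whole conversation, not a string, so the model's reply is the last message in that list.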