Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Add descriptive docstring to TemperatureLogitsWarper (#24892)
* Add descriptive docstring to TemperatureLogitsWarper

  It addresses https://github.com/huggingface/transformers/issues/24783

* Remove niche features

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Commit suggestion

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Refactor the examples to simpler ones

* Add a missing comma

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Make args description more compact

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Remove extra text after making description more compact

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Fix linter

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 31acba5697
commit 04a5c859b0
@@ -172,11 +172,58 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
 
 class TemperatureLogitsWarper(LogitsWarper):
     r"""
-    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+    [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
+    that it can control the randomness of the predicted tokens.
+
+    <Tip>
+
+    Make sure that `do_sample=True` is included in the `generate` arguments otherwise the temperature value won't have
+    any effect.
+
+    </Tip>
 
     Args:
         temperature (`float`):
-            The value used to module the logits distribution.
+            Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
+            randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
+            token.
+
+    Examples:
+
+    ```python
+    >>> import torch
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> model.config.pad_token_id = model.config.eos_token_id
+    >>> model.generation_config.pad_token_id = model.config.eos_token_id
+    >>> input_context = "Hugging Face Company is"
+    >>> input_ids = tokenizer.encode(input_context, return_tensors="pt")
+
+    >>> torch.manual_seed(0)
+
+    >>> # With temperature=1, the default, we consistently get random outputs due to random sampling.
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=1, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is one of these companies that is going to take a
+
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=1, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is one of these companies, you can make a very
+
+    >>> # However, with temperature close to 0, the output remains invariant.
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=0.0001, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a company that has been around for over 20 years
+
+    >>> # even if we set a different seed.
+    >>> torch.manual_seed(42)
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=0.0001, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a company that has been around for over 20 years
+    ```
     """
 
     def __init__(self, temperature: float):
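For context (not part of the diff above): the warper's effect reduces to an element-wise division of the logits by `temperature` before softmax, which is why values below `1` sharpen the distribution and values above `1` flatten it. A minimal standalone sketch of that behavior on toy logits, assuming a `transformers` release where `TemperatureLogitsWarper` is importable from the top-level package:

```python
import torch
from transformers import TemperatureLogitsWarper

# Toy next-token logits for a 3-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.0]])
# Required by the __call__ signature, but unused by this particular warper.
input_ids = torch.tensor([[0]])

for temperature in (1.0, 0.5, 2.0):
    warper = TemperatureLogitsWarper(temperature=temperature)
    scaled = warper(input_ids, logits)  # equivalent to logits / temperature
    probs = torch.softmax(scaled, dim=-1)
    # temperature < 1 sharpens the distribution (more mass on the top token);
    # temperature > 1 flattens it (more randomness when sampling).
    print(f"T={temperature}: {probs.tolist()}")
```

With these logits, the top token's probability is roughly 0.67 at `T=1`, rises to about 0.87 at `T=0.5`, and drops to about 0.51 at `T=2`, matching the docstring's claim that values below `1` decrease randomness. As the Tip notes, this only matters during generation when `do_sample=True`, since greedy decoding ignores the rescaled distribution's shape.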