[generate] update docstring of SequenceBiasLogitsProcessor (#35699)

* fix docstring

* space

parent 56afd2f488
commit 88b95e6179
@@ -1040,10 +1040,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
 
     <Tip>
 
-    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
-    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
-    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
-    come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
+    At a token-level, biasing a word is different from biasing a word with a space before it. If you want to bias
+    "foo" mid-sentence, you'll likely want to add a prefix space and bias " foo" instead. Check the tokenizer section
+    of our NLP course to find out why: https://huggingface.co/learn/nlp-course/chapter2/4?fw=pt
 
     </Tip>
 
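The replacement tip makes a token-level claim that is easy to verify: most BPE tokenizers assign different ids to a word with and without a leading space. A minimal sketch of that check, reusing the Qwen checkpoint from the updated example below (the printed ids vary by tokenizer and are only illustrative):

```python
from transformers import AutoTokenizer

# Same checkpoint as the updated doctest below.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# "Trump" and " Trump" map to different token ids, so a bias on the
# un-prefixed form will not match the word as it appears mid-sentence.
print(tokenizer("Trump", add_special_tokens=False).input_ids)
print(tokenizer(" Trump", add_special_tokens=False).input_ids)
```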
@@ -1060,37 +1059,40 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
 
     ```python
     >>> from transformers import AutoTokenizer, AutoModelForCausalLM
 
-    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
-    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
     >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
 
-    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
+    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4, do_sample=False)
     >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
-    The full name of Donald is Donald J. Trump Jr
-
-    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
-    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)
+    The full name of Donald is Donald John Trump Sr.
 
     >>> def get_tokens(word):
-    ...     return tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
+    ...     return tokenizer([word], add_special_tokens=False).input_ids[0]
 
-    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
-    >>> sequence_bias = [get_tokens("Trump"), -10.0]
-    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
+    >>> # IMPORTANT: Remember our tip about adding spaces before words to bias them correctly.
+    >>> sequence_bias = [[get_tokens("Trump"), -10.0],]  # will fail to apply bias
+    >>> biased_ids = model.generate(
+    ...     inputs["input_ids"], max_new_tokens=4, do_sample=False, sequence_bias=sequence_bias
+    ... )
     >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
-    The full name of Donald is Donald J. Donald,
+    The full name of Donald is Donald John Trump Sr.
 
-    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
+    >>> sequence_bias = [[get_tokens(" Trump"), -10.0],]  # will work
+    >>> biased_ids = model.generate(
+    ...     inputs["input_ids"], max_new_tokens=4, do_sample=False, sequence_bias=sequence_bias
+    ... )
     >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
-    The full name of Donald is Donald Rumsfeld,
+    The full name of Donald is Donald John Harper. He
 
-    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
-    >>> sequence_bias = [get_tokens("Donald Duck"), 10.0]
-    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
+    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations. This technique
+    >>> # is also more effective when paired up with beam search.
+    >>> sequence_bias = [[get_tokens(" Donald Duck"), 10.0],]
+    >>> biased_ids = model.generate(
+    ...     inputs["input_ids"], max_new_tokens=4, num_beams=4, do_sample=False, sequence_bias=sequence_bias
+    ... )
     >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
-    The full name of Donald is Donald Duck.
+    The full name of Donald is Donald Duck. He is
     ```
     """
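To see exactly where the bias lands, the processor can also be called by hand on a logits tensor. A minimal sketch, assuming `SequenceBiasLogitsProcessor` is exported from the top-level `transformers` namespace and accepts the list format used in the updated docstring; the token id, prompt ids, and vocabulary size below are placeholders:

```python
import torch
from transformers import SequenceBiasLogitsProcessor

# Bias a single-token "sequence" (token id 42 is a placeholder) by -10.0,
# using the [[token_ids, bias], ...] format from the updated docstring.
processor = SequenceBiasLogitsProcessor(sequence_bias=[[[42], -10.0]])

input_ids = torch.tensor([[1, 2, 3]])  # placeholder prompt ids
scores = torch.zeros(1, 1000)          # placeholder logits over a toy vocabulary
biased_scores = processor(input_ids, scores)

print(biased_scores[0, 42])  # tensor(-10.); every other logit is unchanged
```

When `sequence_bias` is passed to `generate`, the same processor is constructed internally, so this direct call should mirror the transformation applied in the doctest above.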