[DOCS] add example NoBadWordsLogitsProcessor (#25046)
* add example NoBadWordsLogitsProcessor
* fix L764 & L767
* make style
commit b99f7bd4fc (parent dcb183f4bd)
@@ -747,6 +747,50 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
            List of list of token ids that are not allowed to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["In a word, the cake is a"], return_tensors="pt")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a mess.

    >>> # Now let's control generation by taking the bad words out. Please note that the tokenizer is initialized differently

    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)


    >>> def get_tokens_as_list(word_list):
    ...     "Converts a sequence of words into a list of tokens"
    ...     tokens_list = []
    ...     for word in word_list.split(" "):
    ...         tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
    ...         tokens_list.append(tokenized_word)
    ...     return tokens_list


    >>> word_list = "mess"
    >>> bad_words_ids = get_tokens_as_list(word_list=word_list)

    >>> badwords_ids = model.generate(
    ...     inputs["input_ids"],
    ...     max_new_tokens=5,
    ...     bad_words_ids=bad_words_ids,
    ...     eos_token_id=tokenizer_with_prefix_space.eos_token_id,
    ... )
    >>> print(tokenizer.batch_decode(badwords_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a surprise.

    >>> badwords_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=5, bad_words_ids=bad_words_ids)
    >>> print(tokenizer.batch_decode(badwords_ids, skip_special_tokens=True)[0])
    In a word, the cake is a great way to start
    ```
    """

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
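For quick reference, the processor documented by this commit can also be constructed directly and handed to `generate` through a `LogitsProcessorList`, rather than via the `bad_words_ids` argument used in the docstring example. The snippet below is not part of the commit; it is a minimal sketch assuming the same `gpt2` checkpoint and prompt as above.

```python
# Sketch only (not part of this commit): instantiate NoBadWordsLogitsProcessor directly
# and pass it to `generate` via `logits_processor`, reusing the "gpt2" setup from the
# docstring example above.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    NoBadWordsLogitsProcessor,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

inputs = tokenizer(["In a word, the cake is a"], return_tensors="pt")

# Token ids of the banned word, tokenized with a leading space (as in the docstring example).
bad_words_ids = tokenizer_with_prefix_space(["mess"], add_special_tokens=False).input_ids

processor = NoBadWordsLogitsProcessor(
    bad_words_ids=bad_words_ids,
    eos_token_id=tokenizer_with_prefix_space.eos_token_id,
)

# Passing the processor explicitly should behave like the `bad_words_ids` argument of `generate`.
output_ids = model.generate(
    inputs["input_ids"],
    max_new_tokens=5,
    pad_token_id=tokenizer.eos_token_id,
    logits_processor=LogitsProcessorList([processor]),
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```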