[DOCS] add example NoBadWordsLogitsProcessor (#25046)

* add example NoBadWordsLogitsProcessor

* fix L764 & L767

* make style
Gema Parreño 2023-07-25 15:41:48 +02:00 committed by GitHub
parent dcb183f4bd
commit b99f7bd4fc


@@ -747,6 +747,50 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
List of list of token ids that are not allowed to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> inputs = tokenizer(["In a word, the cake is a"], return_tensors="pt")
>>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
In a word, the cake is a bit of a mess.
>>> # Now let's take the bad words out of the generation. Note that the tokenizer is initialized differently:
>>> # `add_prefix_space=True` makes the banned words tokenize the way they appear mid-sentence.
>>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
>>> def get_tokens_as_list(word_list):
... "Converts a sequence of words into a list of tokens"
... tokens_list = []
... for word in word_list.split(" "):
... tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
... tokens_list.append(tokenized_word)
... return tokens_list
>>> word_list = "mess"
>>> bad_words_ids = get_tokens_as_list(word_list=word_list)
>>> output_ids = model.generate(
... inputs["input_ids"],
... max_new_tokens=5,
... bad_words_ids=bad_words_ids,
... eos_token_id=tokenizer_with_prefix_space.eos_token_id,
... )
>>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
In a word, the cake is a bit of a surprise.
>>> output_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=5, bad_words_ids=bad_words_ids)
>>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
In a word, the cake is a great way to start
```
"""
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
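
Beyond the `generate(bad_words_ids=...)` path shown in the new docstring example, the processor this class implements can also be applied by hand. The sketch below is not part of the commit; it assumes the `NoBadWordsLogitsProcessor` and `LogitsProcessorList` classes exported by `transformers`, the same `gpt2` checkpoint used above, and illustrative variable names, and only illustrates how banned sequences end up with `-inf` scores.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, NoBadWordsLogitsProcessor

model = AutoModelForCausalLM.from_pretrained("gpt2")
# add_prefix_space=True so the banned word is tokenized the way it appears inside a sentence
tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

# bad_words_ids is a List[List[int]]: one inner list of token ids per banned sequence
bad_words_ids = tokenizer(["mess"], add_special_tokens=False).input_ids
processor = NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids, eos_token_id=tokenizer.eos_token_id)

input_ids = tokenizer("In a word, the cake is a", return_tensors="pt").input_ids
with torch.no_grad():
    next_token_scores = model(input_ids).logits[:, -1, :]

# The processor sets the scores of banned continuations to -inf
filtered_scores = LogitsProcessorList([processor])(input_ids, next_token_scores)
# True when the banned word (" mess") is a single token and is therefore blocked at every step
print(torch.isinf(filtered_scores[0, bad_words_ids[0][0]]).item())
```

Passing the same list through `bad_words_ids=` in `generate()` builds this processor internally, which is what the docstring example above relies on.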