mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

fix doctest

parent 98a2db8311
commit 4772768457
@@ -258,8 +258,7 @@ class TemperatureLogitsWarper(LogitsWarper):
 >>> generate_kwargs = {"max_new_tokens": 10, "do_sample": True, "temperature": 1.0, "num_return_sequences": 2}
 >>> outputs = model.generate(**inputs, **generate_kwargs)
 >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
-['Hugging Face Company is a joint venture between GEO Group, one of',
-'Hugging Face Company is not an exact science – but what we believe does']
+['Hugging Face Company is one of these companies that is going to take a', "Hugging Face Company is a brand created by Brian A. O'Neil"]

 >>> # However, with temperature close to 0, it approximates greedy decoding strategies (invariant)
 >>> generate_kwargs["temperature"] = 0.0001
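Context for this hunk: `TemperatureLogitsWarper` amounts to dividing the logits by the temperature before softmax, which is why a near-zero temperature approximates greedy decoding. A minimal standalone sketch (illustrative, not the transformers implementation):

```python
import torch

def temperature_warp(scores: torch.FloatTensor, temperature: float) -> torch.FloatTensor:
    # Dividing by T < 1 sharpens the distribution; T -> 0 approaches argmax.
    return scores / temperature

logits = torch.tensor([[2.0, 1.0, 0.5]])
print(torch.softmax(temperature_warp(logits, 1.0), dim=-1))     # ~[0.63, 0.23, 0.14]
print(torch.softmax(temperature_warp(logits, 0.0001), dim=-1))  # ~[1.00, 0.00, 0.00], near-greedy
```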
@@ -425,7 +424,9 @@ class TopPLogitsWarper(LogitsWarper):
 >>> # With sampling, the output is unexpected -- sometimes too unexpected.
 >>> outputs = model.generate(**inputs, do_sample=True)
 >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
-A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
+A sequence: 1, 2, 3, 4, 5, 6
+<BLANKLINE>
+The data are a

 >>> # With `top_p` sampling, the output gets restricted to high-probability tokens.
 >>> # Pro tip: In practice, LLMs use `top_p` in the 0.9-0.95 range.
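For reference, nucleus (`top_p`) filtering is a sort, a cumulative sum, and a mask. A self-contained sketch of the idea, not the exact `TopPLogitsWarper` code:

```python
import torch

def top_p_filter(logits: torch.FloatTensor, top_p: float = 0.9) -> torch.FloatTensor:
    """Mask everything outside the smallest token set with cumulative prob >= top_p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Remove tokens once the cumulative probability already exceeds top_p ...
    remove = cumulative > top_p
    # ... but shift right so the token that crosses the threshold is kept.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    # Scatter the sorted-order mask back into vocabulary order.
    mask = remove.scatter(dim=-1, index=sorted_indices, src=remove)
    return logits.masked_fill(mask, float("-inf"))

logits = torch.tensor([[3.0, 2.0, 1.0, 0.0]])  # probs ~[0.64, 0.24, 0.09, 0.03]
print(top_p_filter(logits, top_p=0.9))          # the low-probability tail becomes -inf
```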
@@ -489,13 +490,15 @@ class TopKLogitsWarper(LogitsWarper):
 >>> # With sampling, the output is unexpected -- sometimes too unexpected.
 >>> outputs = model.generate(**inputs, do_sample=True)
 >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
-A sequence: A, B, C, D, G, H, I. A, M
+A sequence: A, B, C, D, E, N, O, P, P

 >>> # With `top_k` sampling, the output gets restricted the k most likely tokens.
 >>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
 >>> outputs = model.generate(**inputs, do_sample=True, top_k=2)
 >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
-A sequence: A, B, C, D, E, F, G, H, I
+A sequence: A, B, C, D, E.
+<BLANKLINE>
+The sequence is a sequence
 ```
 """
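`top_k` is simpler still: keep the k highest logits and mask the rest. A sketch along the lines of `TopKLogitsWarper`:

```python
import torch

def top_k_filter(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    top_k = min(top_k, logits.size(-1))  # guard against k > vocab size
    # The k-th best logit per row is the cutoff; anything below it is removed.
    cutoff = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < cutoff, float("-inf"))

logits = torch.tensor([[3.0, 2.0, 1.0, 0.0]])
print(top_k_filter(logits, top_k=2))  # tensor([[3., 2., -inf, -inf]])
```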
@@ -630,7 +633,9 @@ class EpsilonLogitsWarper(LogitsWarper):
 >>> # With sampling, the output is unexpected -- sometimes too unexpected.
 >>> outputs = model.generate(**inputs, do_sample=True)
 >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
-A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
+A sequence: 1, 2, 3, 4, 5, 6
+<BLANKLINE>
+The data are a

 >>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to
 >>> # Top P sampling, which restricts tokens based on their cumulative probability.
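Epsilon sampling drops any token whose probability falls below a fixed threshold. Roughly, ignoring the `min_tokens_to_keep` safeguard of the real `EpsilonLogitsWarper`:

```python
import torch

def epsilon_filter(logits: torch.FloatTensor, epsilon: float) -> torch.FloatTensor:
    probs = torch.softmax(logits, dim=-1)
    remove = probs < epsilon
    # Never mask the single most likely token, so sampling always has a candidate.
    remove.scatter_(-1, probs.argmax(dim=-1, keepdim=True), False)
    return logits.masked_fill(remove, float("-inf"))

logits = torch.tensor([[3.0, 2.0, 1.0, 0.0]])  # probs ~[0.64, 0.24, 0.09, 0.03]
print(epsilon_filter(logits, epsilon=0.05))    # masks only the 0.03 token
```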
@@ -707,7 +712,9 @@ class EtaLogitsWarper(LogitsWarper):
 >>> # With sampling, the output is unexpected -- sometimes too unexpected.
 >>> outputs = model.generate(**inputs, do_sample=True)
 >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
-A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
+A sequence: 1, 2, 3, 4, 5, 6
+<BLANKLINE>
+The data are a

 >>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of
 >>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff).
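Eta sampling makes that cutoff adaptive, eta = min(epsilon, sqrt(epsilon) * exp(-H(p))), so high-entropy (flat) distributions get a lower threshold. A sketch of the rule as described in the comments above, again skipping the `min_tokens_to_keep` safeguard:

```python
import torch

def eta_filter(logits: torch.FloatTensor, epsilon: float) -> torch.FloatTensor:
    probs = torch.softmax(logits, dim=-1)
    entropy = torch.distributions.Categorical(logits=logits).entropy()
    eps = torch.tensor(epsilon)
    # Entropy-dependent cutoff: high entropy -> exp(-H) small -> lower threshold.
    eta = torch.min(eps, torch.sqrt(eps) * torch.exp(-entropy))[..., None]
    remove = probs < eta
    remove.scatter_(-1, probs.argmax(dim=-1, keepdim=True), False)  # keep the argmax
    return logits.masked_fill(remove, float("-inf"))
```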
@@ -1215,10 +1222,10 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
 ...     In this case, `batch_id` is not used, but you can set rules for each batch member.
 ...     '''
 ...     if input_ids[-1] == entity[0]:
-...         return entity[0:1].tolist()
-...     elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
-...         return entity[1:2].tolist()
-...     return list(range(tokenizer.vocab_size))  # If no match, allow all tokens
+...         return entity[1:2].tolist()
+...     elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
+...         return entity[2:3].tolist()
+...     return list(range(tokenizer.vocab_size))  # If no match, allow all tokens

 >>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
 >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
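The fix here is an off-by-one: once the last token equals `entity[0]`, the callback must allow the next entity token (`entity[1:2]`), not re-emit the one just generated. For context, `generate` turns the returned id list into an additive mask; `apply_prefix_constraint` below is a hypothetical helper sketching that step, not the transformers class itself:

```python
import torch

def apply_prefix_constraint(scores, input_ids, prefix_allowed_tokens_fn):
    # Start from an all -inf mask and zero out the allowed token ids per batch entry,
    # so adding the mask leaves allowed logits untouched and bans everything else.
    mask = torch.full_like(scores, float("-inf"))
    for batch_id, sent in enumerate(input_ids):
        allowed = prefix_allowed_tokens_fn(batch_id, sent)
        mask[batch_id, allowed] = 0.0
    return scores + mask
```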
@@ -1616,7 +1623,7 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
 >>> # distribution, summing to 1
 >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
 >>> print(torch.sum(torch.exp(outputs.scores[-1])))
-tensor(816.3250)
+tensor(816.2924)

 >>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
 >>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
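Only the expected value changed here; the point of the doctest stands: raw `scores` are unnormalized log-probabilities, so exp-summing them gives ~816 rather than 1, and `renormalize_logits=True` applies a log-softmax. For example:

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5]])
print(torch.sum(torch.exp(scores)))             # > 1, unnormalized
normalized = torch.log_softmax(scores, dim=-1)  # the log-softmax that renormalize_logits applies
print(torch.sum(torch.exp(normalized)))         # tensor(1.) up to float error
```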
@@ -1652,7 +1659,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
 >>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
 >>> # it can't generate and EOS token in the first iteration, but it can in the others.
 >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
->>> print(outputs.scores[1][0, 50256])  # 1 (and not 0) is the first freely generated token
+>>> print(outputs.scores[0][0, 50256])  # 1 (and not 0) is the first freely generated token
 tensor(-inf)
 >>> print(outputs.scores[-1][0, 50256])  # in other places we can see some probability mass for EOS
 tensor(29.9010)
@@ -1662,7 +1669,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
 ...     **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
 ... )
 >>> print(outputs.scores[1][0, 50256])
-tensor(11.2027)
+tensor(8.1599)
 ```
 """
@@ -1708,9 +1715,10 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
 tensor(-inf)

 >>> # If we disable `suppress_tokens`, we can generate it.
->>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
+>>> model.generation_config.suppress_tokens = None
+>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
 >>> print(outputs.scores[1][0, 1])
-tensor(5.7738)
+tensor(6.0678)
 ```
 """
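Both suppression processors reduce to writing -inf into the banned columns; the at-begin variant additionally checks that generation is at its first free step (hence the `scores[0]` vs `scores[1]` bookkeeping above). A rough standalone sketch of the two behaviours:

```python
import torch

def suppress_tokens(scores: torch.FloatTensor, suppress_ids: list) -> torch.FloatTensor:
    # SuppressTokensLogitsProcessor: banned ids can never be sampled, at any step.
    scores[:, suppress_ids] = float("-inf")
    return scores

def suppress_tokens_at_begin(scores, input_ids, suppress_ids, begin_index):
    # SuppressTokensAtBeginLogitsProcessor: same mask, but only when the sequence
    # is exactly at the first freely generated position.
    if input_ids.shape[-1] == begin_index:
        scores[:, suppress_ids] = float("-inf")
    return scores
```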
@@ -1743,24 +1751,25 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
 >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
 >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")

->>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e.
->>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out.
->>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
+>>> # We can use an example in Whisper model (note: forced_decoder_ids is deprecated from v4.39)
+>>> # `"forced_decoder_ids": [[6, 50362]]` means all other tokens are masked out at position 6 except for 50362.
+>>> # Because `output.scores` keeps track of generated tokens only, and not the input ids, we check `scores[4]`
+>>> outputs = model.generate(**inputs, forced_decoder_ids=[[6, 50362]], return_dict_in_generate=True, output_scores=True)
 >>> print(
-...     all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
+...     all(outputs.scores[4][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
 ... )
 True
->>> print(outputs.scores[0][0, 50362])
+>>> print(outputs.scores[4][0, 50362])
 tensor(0.)

 >>> # If we disable `forced_decoder_ids`, we stop seeing that effect
 >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
 >>> print(
-...     all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
+...     all(outputs.scores[4][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
 ... )
 False
->>> print(outputs.scores[0][0, 50362])
-tensor(19.3140)
+>>> print(outputs.scores[4][0, 50362])
+tensor(-2.4375)
 ```
 """
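Context for the index change: `forced_decoder_ids=[[6, 50362]]` forces absolute position 6, but `output.scores` only covers generated tokens, so the forced step shows up at `scores[4]` once the decoder prompt tokens are excluded. Forcing itself is a full mask with log-prob 0 on the forced id, which is why the doctest prints `tensor(0.)`. A sketch of that behaviour (`force_token` is a hypothetical helper, not the deprecated processor's exact code):

```python
import torch

def force_token(scores, input_ids, forced_decoder_ids):
    # forced_decoder_ids maps absolute generation index -> forced token id.
    forced = dict(forced_decoder_ids)
    cur_position = input_ids.shape[-1]  # index of the token being generated now
    if cur_position in forced:
        token_id = forced[cur_position]
        scores[:, :] = float("-inf")  # mask every token ...
        scores[:, token_id] = 0.0     # ... except the forced one (log-prob 0, hence tensor(0.))
    return scores
```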