Mirror of https://github.com/huggingface/transformers.git
Add implementation of typical sampling (#15504)
* typical decoding
* changing arg name
* add test config params
* forgotten arg rename
* fix edge case where scores are same
* test for typical logits warper
* code quality fixes
Commit 0113aae5b7 (parent f588cf4050)
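For context (not part of the commit itself): after this change, typical sampling can be enabled end to end by passing `typical_p` to `generate()`. A minimal usage sketch; the checkpoint name and prompt below are placeholders, not taken from the diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The meaning of life is", return_tensors="pt")
# typical_p < 1.0 activates the new TypicalLogitsWarper; do_sample=True selects the sampling path.
outputs = model.generate(**inputs, do_sample=True, typical_p=0.2, max_length=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))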
@@ -282,6 +282,7 @@ class PretrainedConfig(PushToHubMixin):
         self.temperature = kwargs.pop("temperature", 1.0)
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
+        self.typical_p = kwargs.pop("typical_p", 1.0)
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
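A quick sanity check of the new default (a sketch, assuming the hunk above is applied; `1.0` effectively disables typical filtering):

from transformers import PretrainedConfig

config = PretrainedConfig()
print(config.typical_p)  # 1.0 by default; generate() only adds the warper when typical_p < 1.0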
@@ -239,6 +239,39 @@ class TopKLogitsWarper(LogitsWarper):
         return scores


+class TypicalLogitsWarper(LogitsWarper):
+    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+
+        self.filter_value = filter_value
+        self.mass = mass
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        # calculate entropy
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift and sort
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        last_ind[last_ind < 0] = 0
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+
 def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
     generated_ngrams = [{} for _ in range(num_hypos)]
     for idx in range(num_hypos):
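To make the arithmetic above concrete, here is a self-contained sketch of the same filtering rule on a toy distribution (plain PyTorch, independent of the class in the diff; it omits the `min_tokens_to_keep` safety valve and the index clamping for brevity):

import torch

def typical_filter(logits: torch.Tensor, mass: float = 0.9) -> torch.Tensor:
    # Keep the tokens whose surprisal (-log p) is closest to the distribution's
    # entropy, adding them in order of "typicality" until `mass` is covered.
    log_p = torch.nn.functional.log_softmax(logits, dim=-1)
    p = log_p.exp()
    entropy = -(p * log_p).nansum(-1, keepdim=True)

    shifted = ((-log_p) - entropy).abs()                     # distance of each token from the entropy
    sorted_shifted, sorted_idx = torch.sort(shifted, descending=False)
    cum_probs = logits.gather(-1, sorted_idx).softmax(dim=-1).cumsum(dim=-1)

    last_ind = (cum_probs < mass).sum(dim=-1, keepdim=True)  # index of the last kept token
    cutoff = sorted_shifted.gather(-1, last_ind)
    return logits.masked_fill(shifted > cutoff, -float("inf"))

logits = torch.log(torch.tensor([[0.4, 0.3, 0.2, 0.1]]))
print(typical_filter(logits, mass=0.5).softmax(dim=-1))
# keeps the 0.3 and 0.2 tokens (closest to the entropy), not the most probable 0.4 token

Note how, unlike top-p, the most probable token can be dropped when it is far less "typical" than the rest of the distribution.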
@@ -40,6 +40,7 @@ from .generation_logits_process import (
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
+    TypicalLogitsWarper,
 )
 from .generation_stopping_criteria import (
     MaxLengthCriteria,
@@ -620,7 +621,12 @@ class GenerationMixin:
         )

     def _get_logits_warper(
-        self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
+        self,
+        top_k: int = None,
+        top_p: float = None,
+        typical_p: float = None,
+        temperature: float = None,
+        num_beams: int = None,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -630,6 +636,7 @@ class GenerationMixin:
         # init warp parameters
         top_k = top_k if top_k is not None else self.config.top_k
         top_p = top_p if top_p is not None else self.config.top_p
+        typical_p = typical_p if typical_p is not None else self.config.typical_p
         temperature = temperature if temperature is not None else self.config.temperature
         # instantiate warpers list
         warpers = LogitsProcessorList()
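A small sketch of the fallback added here: if `typical_p` is not passed to `generate()`, the value stored on the config is used instead (checkpoint name and prompt are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")              # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.typical_p = 0.2                             # picked up by _get_logits_warper
ids = tok("Typical sampling", return_tensors="pt").input_ids
out = model.generate(ids, do_sample=True, max_length=30)  # behaves like passing typical_p=0.2 explicitly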
@@ -642,6 +649,8 @@ class GenerationMixin:
             warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
+        if typical_p is not None and typical_p < 1.0:
+            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
         return warpers

     def _get_logits_processor(
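The same kind of chain can be built by hand and applied to raw next-token logits; a sketch, assuming the module path implied by the relative import above (`transformers.generation_logits_process`) and an arbitrary fake vocabulary size:

import torch
from transformers.generation_logits_process import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TypicalLogitsWarper,
)

warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50),
        TypicalLogitsWarper(mass=0.2),
    ]
)
scores = torch.randn(1, 1000)             # fake logits for one sequence over a 1000-token vocabulary
warped = warpers(None, scores)            # input_ids are unused by these warpers
print((warped != -float("inf")).sum())    # number of tokens that survive the chain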
@@ -811,6 +820,7 @@ class GenerationMixin:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
         repetition_penalty: Optional[float] = None,
         bad_words_ids: Optional[Iterable[int]] = None,
         bos_token_id: Optional[int] = None,
@@ -1191,7 +1201,7 @@ class GenerationMixin:
         elif is_sample_gen_mode:
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
-                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
+                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
             )

             # 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1253,7 +1263,7 @@ class GenerationMixin:
         elif is_beam_sample_gen_mode:
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
-                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
+                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
             )

             if stopping_criteria.max_length is None:
@@ -58,6 +58,7 @@ config_common_kwargs = {
     "temperature": 2.0,
     "top_k": 10,
     "top_p": 0.7,
+    "typical_p": 0.2,
     "repetition_penalty": 0.8,
     "length_penalty": 0.8,
     "no_repeat_ngram_size": 5,
@@ -41,6 +41,7 @@ if is_torch_available():
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
+        TypicalLogitsWarper,
     )


@@ -191,6 +192,51 @@ class LogitsProcessorTest(unittest.TestCase):
         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
         self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])

+    def test_typical_dist_warper(self):
+        input_ids = None
+        vocab_size = 10
+        batch_size = 2
+
+        # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
+        dist = torch.log(
+            torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
+        )
+
+        typical_warp = TypicalLogitsWarper(0.5)
+        filtered_dist = torch.exp(typical_warp(input_ids, dist))
+
+        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # exp (-inf) => 0
+        EXPECTED_FILTERED_DIST = torch.tensor(
+            [[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
+        )
+        self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
+
+        # check special cases
+        length = 5
+
+        logits = self._get_uniform_logits(batch_size=batch_size, length=length)
+        typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)
+
+        scores = typical_warp_safety_check(input_ids, logits)
+        # uniform dist is not changed
+        self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
+
+        # check edge cases with negative and extreme logits
+        ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
+            batch_size, 1
+        ) - (vocab_size // 2)
+
+        # make ramp_logits more extreme
+        ramp_logits[1] = ramp_logits[1] * 100.0
+
+        # make sure at least 2 tokens are kept
+        typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
+        filtered_dist = typical_warp(input_ids, ramp_logits)
+
+        # first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
+        self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
+
     def test_no_repeat_ngram_dist_processor(self):
         vocab_size = 3
         batch_size = 2
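To see why the second test row keeps the three 0.2 tokens and drops the most probable 0.4 token, here is a worked check of the numbers (a sketch, values rounded; not part of the test file):

import math

p = [0.4, 0.2, 0.2, 0.2]
H = -sum(q * math.log(q) for q in p)              # entropy of the row, about 1.33 nats
dist_from_H = [abs(-math.log(q) - H) for q in p]  # about [0.42, 0.28, 0.28, 0.28]
# The 0.2 tokens sit closer to the entropy, so they are sorted first; their cumulative
# probability reaches 0.6 >= mass=0.5 before the 0.4 token is included, which is why
# EXPECTED_FILTERED_DIST is [0.0, 0.2, 0.2, 0.2].
print(round(H, 2), [round(d, 2) for d in dist_from_H])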