From 0113aae5b7a3e9de7f6300c71ca593a5fdc3b0c2 Mon Sep 17 00:00:00 2001 From: Clara Meister Date: Wed, 9 Feb 2022 16:48:41 +0100 Subject: [PATCH] Add implementation of typical sampling (#15504) * typical decoding * changing arg name * add test config params * forgotten arg rename * fix edge case where scores are same * test for typical logits warper * code quality fixes --- src/transformers/configuration_utils.py | 1 + src/transformers/generation_logits_process.py | 33 +++++++++++++ src/transformers/generation_utils.py | 16 +++++-- tests/test_configuration_common.py | 1 + tests/test_generation_logits_process.py | 46 +++++++++++++++++++ 5 files changed, 94 insertions(+), 3 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 951645ef0fa..3a43677a704 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -282,6 +282,7 @@ class PretrainedConfig(PushToHubMixin): self.temperature = kwargs.pop("temperature", 1.0) self.top_k = kwargs.pop("top_k", 50) self.top_p = kwargs.pop("top_p", 1.0) + self.typical_p = kwargs.pop("typical_p", 1.0) self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) self.length_penalty = kwargs.pop("length_penalty", 1.0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 3e8a2f3cff9..ad79273502e 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -239,6 +239,39 @@ class TopKLogitsWarper(LogitsWarper): return scores +class TypicalLogitsWarper(LogitsWarper): + def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + + self.filter_value = filter_value + self.mass = mass + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < self.mass).sum(dim=1) + last_ind[last_ind < 0] = 0 + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int): generated_ngrams = [{} for _ in range(num_hypos)] for idx in range(num_hypos): diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 537c7cdecc3..895ac031141 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -40,6 +40,7 @@ from .generation_logits_process import ( TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, + TypicalLogitsWarper, ) from .generation_stopping_criteria import ( MaxLengthCriteria, @@ -620,7 +621,12 @@ class GenerationMixin: ) def _get_logits_warper( - self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None + self, + top_k: int = None, + top_p: float = None, + typical_p: float = None, + temperature: float = None, + num_beams: int = None, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances @@ -630,6 +636,7 @@ class GenerationMixin: # init warp parameters top_k = top_k if top_k is not None else self.config.top_k top_p = top_p if top_p is not None else self.config.top_p + typical_p = typical_p if typical_p is not None else self.config.typical_p temperature = temperature if temperature is not None else self.config.temperature # instantiate warpers list warpers = LogitsProcessorList() @@ -642,6 +649,8 @@ class GenerationMixin: warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) if top_p is not None and top_p < 1.0: warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if typical_p is not None and typical_p < 1.0: + warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) return warpers def _get_logits_processor( @@ -811,6 +820,7 @@ class GenerationMixin: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + typical_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, bos_token_id: Optional[int] = None, @@ -1191,7 +1201,7 @@ class GenerationMixin: elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( - top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams + top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams ) # 11. expand input_ids with `num_return_sequences` additional sequences per batch @@ -1253,7 +1263,7 @@ class GenerationMixin: elif is_beam_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( - top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams + top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams ) if stopping_criteria.max_length is None: diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 034ad564694..2b4a023d91c 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -58,6 +58,7 @@ config_common_kwargs = { "temperature": 2.0, "top_k": 10, "top_p": 0.7, + "typical_p": 0.2, "repetition_penalty": 0.8, "length_penalty": 0.8, "no_repeat_ngram_size": 5, diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py index e07fd3066e2..7b6c78b286e 100644 --- a/tests/test_generation_logits_process.py +++ b/tests/test_generation_logits_process.py @@ -41,6 +41,7 @@ if is_torch_available(): TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, + TypicalLogitsWarper, ) @@ -191,6 +192,51 @@ class LogitsProcessorTest(unittest.TestCase): # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2]) + def test_typical_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper) + dist = torch.log( + torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float) + ) + + typical_warp = TypicalLogitsWarper(0.5) + filtered_dist = torch.exp(typical_warp(input_ids, dist)) + + # dist should be filtered to keep min num values so that sum is >= 0.7 + # exp (-inf) => 0 + EXPECTED_FILTERED_DIST = torch.tensor( + [[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float + ) + self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + + # check special cases + length = 5 + + logits = self._get_uniform_logits(batch_size=batch_size, length=length) + typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3) + + scores = typical_warp_safety_check(input_ids, logits) + # uniform dist is not changed + self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0]) + + # check edge cases with negative and extreme logits + ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat( + batch_size, 1 + ) - (vocab_size // 2) + + # make ramp_logits more extreme + ramp_logits[1] = ramp_logits[1] * 100.0 + + # make sure at least 2 tokens are kept + typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0) + filtered_dist = typical_warp(input_ids, ramp_logits) + + # first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. + self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2]) + def test_no_repeat_ngram_dist_processor(self): vocab_size = 3 batch_size = 2