Add implementation of typical sampling (#15504)

* typical decoding

* changing arg name

* add test config params

* forgotten arg rename

* fix edge case where scores are same

* test for typical logits warper

* code quality fixes
This commit is contained in:
Clara Meister 2022-02-09 16:48:41 +01:00 committed by GitHub
parent f588cf4050
commit 0113aae5b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 94 additions and 3 deletions

View File

@ -282,6 +282,7 @@ class PretrainedConfig(PushToHubMixin):
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
self.typical_p = kwargs.pop("typical_p", 1.0)
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)

View File

@ -239,6 +239,39 @@ class TopKLogitsWarper(LogitsWarper):
return scores
class TypicalLogitsWarper(LogitsWarper):
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):

View File

@ -40,6 +40,7 @@ from .generation_logits_process import (
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
MaxLengthCriteria,
@ -620,7 +621,12 @@ class GenerationMixin:
)
def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
self,
top_k: int = None,
top_p: float = None,
typical_p: float = None,
temperature: float = None,
num_beams: int = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@ -630,6 +636,7 @@ class GenerationMixin:
# init warp parameters
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
typical_p = typical_p if typical_p is not None else self.config.typical_p
temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = LogitsProcessorList()
@ -642,6 +649,8 @@ class GenerationMixin:
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if typical_p is not None and typical_p < 1.0:
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
return warpers
def _get_logits_processor(
@ -811,6 +820,7 @@ class GenerationMixin:
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
bos_token_id: Optional[int] = None,
@ -1191,7 +1201,7 @@ class GenerationMixin:
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
)
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
@ -1253,7 +1263,7 @@ class GenerationMixin:
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
)
if stopping_criteria.max_length is None:

View File

@ -58,6 +58,7 @@ config_common_kwargs = {
"temperature": 2.0,
"top_k": 10,
"top_p": 0.7,
"typical_p": 0.2,
"repetition_penalty": 0.8,
"length_penalty": 0.8,
"no_repeat_ngram_size": 5,

View File

@ -41,6 +41,7 @@ if is_torch_available():
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
@ -191,6 +192,51 @@ class LogitsProcessorTest(unittest.TestCase):
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
def test_typical_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = torch.log(
torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
)
typical_warp = TypicalLogitsWarper(0.5)
filtered_dist = torch.exp(typical_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)
scores = typical_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = typical_warp(input_ids, ramp_logits)
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2