mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add implementation of typical sampling (#15504)
* typical decoding * changing arg name * add test config params * forgotten arg rename * fix edge case where scores are same * test for typical logits warper * code quality fixes
This commit is contained in:
parent
f588cf4050
commit
0113aae5b7
@ -282,6 +282,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
self.top_k = kwargs.pop("top_k", 50)
|
||||
self.top_p = kwargs.pop("top_p", 1.0)
|
||||
self.typical_p = kwargs.pop("typical_p", 1.0)
|
||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||
|
@ -239,6 +239,39 @@ class TopKLogitsWarper(LogitsWarper):
|
||||
return scores
|
||||
|
||||
|
||||
class TypicalLogitsWarper(LogitsWarper):
|
||||
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
|
||||
self.filter_value = filter_value
|
||||
self.mass = mass
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
# calculate entropy
|
||||
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
||||
p = torch.exp(normalized)
|
||||
ent = -(normalized * p).nansum(-1, keepdim=True)
|
||||
|
||||
# shift and sort
|
||||
shifted_scores = torch.abs((-normalized) - ent)
|
||||
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
|
||||
sorted_logits = scores.gather(-1, sorted_indices)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative mass above the threshold
|
||||
last_ind = (cumulative_probs < self.mass).sum(dim=1)
|
||||
last_ind[last_ind < 0] = 0
|
||||
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
|
||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||
for idx in range(num_hypos):
|
||||
|
@ -40,6 +40,7 @@ from .generation_logits_process import (
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
)
|
||||
from .generation_stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
@ -620,7 +621,12 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
def _get_logits_warper(
|
||||
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
||||
self,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
typical_p: float = None,
|
||||
temperature: float = None,
|
||||
num_beams: int = None,
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
|
||||
@ -630,6 +636,7 @@ class GenerationMixin:
|
||||
# init warp parameters
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
top_p = top_p if top_p is not None else self.config.top_p
|
||||
typical_p = typical_p if typical_p is not None else self.config.typical_p
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
# instantiate warpers list
|
||||
warpers = LogitsProcessorList()
|
||||
@ -642,6 +649,8 @@ class GenerationMixin:
|
||||
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
||||
if typical_p is not None and typical_p < 1.0:
|
||||
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
||||
return warpers
|
||||
|
||||
def _get_logits_processor(
|
||||
@ -811,6 +820,7 @@ class GenerationMixin:
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
typical_p: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
bad_words_ids: Optional[Iterable[int]] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
@ -1191,7 +1201,7 @@ class GenerationMixin:
|
||||
elif is_sample_gen_mode:
|
||||
# 10. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(
|
||||
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
|
||||
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
|
||||
)
|
||||
|
||||
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
@ -1253,7 +1263,7 @@ class GenerationMixin:
|
||||
elif is_beam_sample_gen_mode:
|
||||
# 10. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(
|
||||
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
|
||||
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
|
||||
)
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
|
@ -58,6 +58,7 @@ config_common_kwargs = {
|
||||
"temperature": 2.0,
|
||||
"top_k": 10,
|
||||
"top_p": 0.7,
|
||||
"typical_p": 0.2,
|
||||
"repetition_penalty": 0.8,
|
||||
"length_penalty": 0.8,
|
||||
"no_repeat_ngram_size": 5,
|
||||
|
@ -41,6 +41,7 @@ if is_torch_available():
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
@ -191,6 +192,51 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
|
||||
|
||||
def test_typical_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||
dist = torch.log(
|
||||
torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
|
||||
)
|
||||
|
||||
typical_warp = TypicalLogitsWarper(0.5)
|
||||
filtered_dist = torch.exp(typical_warp(input_ids, dist))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||
[[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
|
||||
)
|
||||
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
# check special cases
|
||||
length = 5
|
||||
|
||||
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
|
||||
typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)
|
||||
|
||||
scores = typical_warp_safety_check(input_ids, logits)
|
||||
# uniform dist is not changed
|
||||
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
|
||||
|
||||
# check edge cases with negative and extreme logits
|
||||
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
|
||||
batch_size, 1
|
||||
) - (vocab_size // 2)
|
||||
|
||||
# make ramp_logits more extreme
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
|
||||
filtered_dist = typical_warp(input_ids, ramp_logits)
|
||||
|
||||
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
||||
|
||||
def test_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
batch_size = 2
|
||||
|
Loading…
Reference in New Issue
Block a user