From 18e5bdbec5b12ad395bfb2a30223c78d74a9c158 Mon Sep 17 00:00:00 2001
From: patrickvonplaten
Date: Tue, 24 Dec 2019 17:18:05 +0100
Subject: [PATCH] fix repetition penalty error in modeling_utils.py

---
 src/transformers/modeling_utils.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8413aad595d..2698816c664 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -728,7 +728,11 @@ class PreTrainedModel(nn.Module):
             if repetition_penalty != 1.0:
                 for i in range(batch_size):
                     for previous_tokens in set(input_ids[i].tolist()):
-                        next_token_logits[i, previous_tokens] /= repetition_penalty
+                        # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
+                        if next_token_logits[i, previous_tokens] < 0:
+                            next_token_logits[i, previous_tokens] *= repetition_penalty
+                        else:
+                            next_token_logits[i, previous_tokens] /= repetition_penalty
 
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
@@ -807,7 +811,11 @@ class PreTrainedModel(nn.Module):
             if repetition_penalty != 1.0:
                 for i in range(batch_size * num_beams):
                     for previous_tokens in set(input_ids[i].tolist()):
-                        scores[i, previous_tokens] /= repetition_penalty
+                        # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
+                        if scores[i, previous_tokens] < 0:
+                            scores[i, previous_tokens] *= repetition_penalty
+                        else:
+                            scores[i, previous_tokens] /= repetition_penalty
 
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
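Note (not part of the patch): a minimal standalone sketch of the bug this patch fixes. The penalty value and logits below are made up for illustration. With plain division, a penalty > 1 moves a negative logit toward zero and therefore raises that token's probability; the sign-aware update always lowers it.

    import torch

    repetition_penalty = 1.3  # illustrative value, not from the patch
    logits = torch.tensor([2.0, -2.0])  # one positive, one negative previous-token logit

    # Old behavior: plain division shrinks the magnitude of both logits.
    # The negative logit moves toward zero, so its probability goes UP.
    print(logits / repetition_penalty)  # tensor([ 1.5385, -1.5385])

    # Patched behavior: multiply negative logits, divide positive ones,
    # so every previously seen token becomes strictly less likely.
    penalized = torch.where(logits < 0, logits * repetition_penalty, logits / repetition_penalty)
    print(penalized)  # tensor([ 1.5385, -2.6000])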