Merge pull request #2303 from patrickvonplaten/fix_error_with_repetition_penalty

fix repetition penalty error in modeling_utils.py
This commit is contained in:
Thomas Wolf 2019-12-25 22:39:20 +01:00 committed by GitHub
commit aeef4823ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -728,6 +728,10 @@ class PreTrainedModel(nn.Module):
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_tokens in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if next_token_logits[i, previous_tokens] < 0:
next_token_logits[i, previous_tokens] *= repetition_penalty
else:
next_token_logits[i, previous_tokens] /= repetition_penalty
if do_sample:
@ -807,6 +811,10 @@ class PreTrainedModel(nn.Module):
if repetition_penalty != 1.0:
for i in range(batch_size * num_beams):
for previous_tokens in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if scores[i, previous_tokens] < 0:
scores[i, previous_tokens] *= repetition_penalty
else:
scores[i, previous_tokens] /= repetition_penalty
if do_sample: