mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Fix overflowing bad word ids (#10889)
* Removes overflowing bad word IDs * Raise warning
This commit is contained in:
parent
1f5ea9e04a
commit
3c12e3c1c4
@ -22,6 +22,10 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||||
@ -417,7 +421,14 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
|||||||
banned_mask_list = []
|
banned_mask_list = []
|
||||||
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
||||||
for token in batch_banned_tokens:
|
for token in batch_banned_tokens:
|
||||||
banned_mask_list.append([idx, token])
|
# Eliminates invalid bad word IDs that are over the vocabulary size.
|
||||||
|
if token <= scores.shape[1]:
|
||||||
|
banned_mask_list.append([idx, token])
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"An invalid bad word ID is defined: {token}. This ID is not contained in the"
|
||||||
|
f"vocabulary, and is therefore ignored."
|
||||||
|
)
|
||||||
if not banned_mask_list:
|
if not banned_mask_list:
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user