[Tokenizer Utils Base] Make pad function more flexible (#9928)
* change tokenizer requirement
* split line
* Correct typo from list to str
* improve style
* make other function pretty as well
* add comment
* correct typo
* add new test
* pass tests for tok without padding token
* Apply suggestions from code review
Parent: d1b14c9b54
Commit: 538b3b4607
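For context, a minimal usage sketch of what the more flexible pad function allows (this is not part of the commit; it assumes a transformers build that includes this change and uses the bert-base-uncased checkpoint purely as an example). It mirrors the new test_padding_different_model_input_name test: the batch's main input is renamed, model_input_names is updated accordingly, and tokenizer.pad(...) still works because it now keys off model_input_names[0] instead of a hard-coded "input_ids".

from transformers import AutoTokenizer

# Example checkpoint (assumption, not taken from the commit).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["short", "a noticeably longer example sentence"])

# Rename the main input, as the new test does, and tell the tokenizer about it.
batch["inputs"] = batch[tokenizer.model_input_names[0]]
del batch[tokenizer.model_input_names[0]]
tokenizer.model_input_names = ["inputs"] + tokenizer.model_input_names[1:]

# Before this change, pad() asserted that "input_ids" was present; now it pads
# whatever key model_input_names[0] points to.
padded = tokenizer.pad(batch, padding="longest")
print(len(padded["inputs"][0]) == len(padded["inputs"][1]))  # True: both rows padded to the same length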
@@ -64,7 +64,7 @@ def get_tfds(
     label_name = features_name.pop(label_column_id)
     label_list = list(set(ds[list(files.keys())[0]][label_name]))
     label2id = {label: i for i, label in enumerate(label_list)}
-    input_names = ["input_ids"] + tokenizer.model_input_names
+    input_names = tokenizer.model_input_names
     transformed_ds = {}

     if len(features_name) == 1:
@@ -98,7 +98,7 @@ if is_tf_available():
                 label = d.pop("label")
                 yield (d, label)

-        input_names = ["input_ids"] + tokenizer.model_input_names
+        input_names = tokenizer.model_input_names

         return tf.data.Dataset.from_generator(
             gen,
@@ -97,7 +97,7 @@ class BarthezTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -106,7 +106,7 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = BarthezTokenizer

     def __init__(
@@ -92,7 +92,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
         },
     }
     max_model_input_sizes = {"facebook/blenderbot_small-90M": 512}
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -100,7 +100,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -110,7 +110,7 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = CamembertTokenizer

     def __init__(
@@ -68,4 +68,4 @@ class DistilBertTokenizer(BertTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
@@ -77,5 +77,5 @@ class DistilBertTokenizerFast(BertTokenizerFast):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = DistilBertTokenizer
@@ -385,4 +385,4 @@ class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
     pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
@@ -387,5 +387,5 @@ class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast):
     pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = DPRReaderTokenizer
@@ -177,7 +177,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -148,7 +148,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -116,7 +116,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = GPT2Tokenizer

     def __init__(
@@ -92,7 +92,7 @@ class MarianTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     language_code_re = re.compile(">>.+<<")  # type: re.Pattern

     def __init__(
@@ -122,7 +122,7 @@ class MPNetTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -102,7 +102,7 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast):
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     slow_tokenizer_class = MPNetTokenizer
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -94,7 +94,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
         super().__init__(unk_token=unk_token, **kwargs)
@@ -61,7 +61,7 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = OpenAIGPTTokenizer

     def __init__(self, vocab_file, merges_file, tokenizer_file=None, unk_token="<unk>", **kwargs):
@@ -84,7 +84,7 @@ class PegasusTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -93,7 +93,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     slow_tokenizer_class = PegasusTokenizer
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -84,7 +84,7 @@ class ReformerTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>", additional_special_tokens=[], **kwargs):
         super().__init__(
@@ -93,7 +93,7 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = ReformerTokenizer

     def __init__(
@@ -53,4 +53,4 @@ class RetriBertTokenizer(BertTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
@@ -58,4 +58,4 @@ class RetriBertTokenizerFast(BertTokenizerFast):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     slow_tokenizer_class = RetriBertTokenizer
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
@@ -129,7 +129,7 @@ class RobertaTokenizer(GPT2Tokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -138,7 +138,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = RobertaTokenizer

     def __init__(
@@ -97,7 +97,7 @@ class T5Tokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -108,7 +108,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = T5Tokenizer

     prefix_tokens: List[int] = []
@@ -151,7 +151,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = []
+    model_input_names = ["input_ids"]

     def __init__(
         self,
@@ -104,7 +104,7 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -102,7 +102,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -114,7 +114,7 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = XLMRobertaTokenizer

     def __init__(
@@ -300,7 +300,7 @@ class QuestionAnsweringPipeline(Pipeline):

         all_answers = []
         for features, example in zip(features_list, examples):
-            model_input_names = self.tokenizer.model_input_names + ["input_ids"]
+            model_input_names = self.tokenizer.model_input_names
             fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}

             # Manage tensor allocation on correct device
@@ -1492,7 +1492,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
     pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
     pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
     max_model_input_sizes: Dict[str, Optional[int]] = {}
-    model_input_names: List[str] = ["token_type_ids", "attention_mask"]
+
+    # first name has to correspond to main model input name
+    # to make sure `tokenizer.pad(...)` works correctly
+    model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
     padding_side: str = "right"
     slow_tokenizer_class = None
@@ -2633,13 +2636,16 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

-        assert "input_ids" in encoded_inputs, (
-            "You should supply an encoding or a list of encodings to this method. "
-            "An encoding is the output of one the encoding methods of the tokenizer, i.e. "
-            "__call__/encode_plus/batch_encode_plus. "
-        )
+        # The model's main input name, usually `input_ids`, has to be passed for padding
+        if self.model_input_names[0] not in encoded_inputs:
+            raise ValueError(
+                "You should supply an encoding or a list of encodings to this method "
+                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
+            )

-        if not encoded_inputs["input_ids"]:
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if not required_input:
             if return_attention_mask:
                 encoded_inputs["attention_mask"] = []
             return encoded_inputs
@@ -2648,14 +2654,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         # and rebuild them afterwards if no return_tensors is specified
         # Note that we lose the specific device the tensor may be on for PyTorch

-        first_element = encoded_inputs["input_ids"][0]
+        first_element = required_input[0]
         if isinstance(first_element, (list, tuple)):
             # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
             index = 0
-            while len(encoded_inputs["input_ids"][index]) == 0:
+            while len(required_input[index]) == 0:
                 index += 1
-            if index < len(encoded_inputs["input_ids"]):
-                first_element = encoded_inputs["input_ids"][index][0]
+            if index < len(required_input):
+                first_element = required_input[index][0]
         # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
         if not isinstance(first_element, (int, list, tuple)):
             if is_tf_available() and _is_tensorflow(first_element):
@@ -2678,7 +2684,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                 padding=padding, max_length=max_length, verbose=verbose
             )

-        if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)):
+        required_input = encoded_inputs[self.model_input_names[0]]
+        if required_input and not isinstance(required_input[0], (list, tuple)):
             encoded_inputs = self._pad(
                 encoded_inputs,
                 max_length=max_length,
@@ -2688,13 +2695,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             )
             return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

-        batch_size = len(encoded_inputs["input_ids"])
+        batch_size = len(required_input)
         assert all(
             len(v) == batch_size for v in encoded_inputs.values()
         ), "Some items in the output dictionary have a different batch size than others."

         if padding_strategy == PaddingStrategy.LONGEST:
-            max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"])
+            max_length = max(len(inputs) for inputs in required_input)
             padding_strategy = PaddingStrategy.MAX_LENGTH

         batch_outputs = {}
@@ -3004,42 +3011,42 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         if return_attention_mask is None:
             return_attention_mask = "attention_mask" in self.model_input_names

+        required_input = encoded_inputs[self.model_input_names[0]]
+
         if padding_strategy == PaddingStrategy.LONGEST:
-            max_length = len(encoded_inputs["input_ids"])
+            max_length = len(required_input)

         if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
             max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

-        needs_to_be_padded = (
-            padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length
-        )
+        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

         if needs_to_be_padded:
-            difference = max_length - len(encoded_inputs["input_ids"])
+            difference = max_length - len(required_input)
             if self.padding_side == "right":
                 if return_attention_mask:
-                    encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
+                    encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
                 if "token_type_ids" in encoded_inputs:
                     encoded_inputs["token_type_ids"] = (
                         encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
                     )
                 if "special_tokens_mask" in encoded_inputs:
                     encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
-                encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
+                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
             elif self.padding_side == "left":
                 if return_attention_mask:
-                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
+                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
                 if "token_type_ids" in encoded_inputs:
                     encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                         "token_type_ids"
                     ]
                 if "special_tokens_mask" in encoded_inputs:
                     encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
-                encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
+                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
             else:
                 raise ValueError("Invalid padding strategy:" + str(self.padding_side))
         elif return_attention_mask and "attention_mask" not in encoded_inputs:
-            encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
+            encoded_inputs["attention_mask"] = [1] * len(required_input)

         return encoded_inputs
@@ -125,7 +125,7 @@ class {{cookiecutter.camelcase_modelname}}Tokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -270,7 +270,7 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["attention_mask"]
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
@@ -160,6 +160,38 @@ class TokenizerTesterMixin:
         # "This is a test",
         # )

+    def assert_padded_input_match(self, input_r: list, input_p: list, max_length: int, pad_token_id: int):
+        # Ensure we match max_length
+        self.assertEqual(len(input_r), max_length)
+        self.assertEqual(len(input_p), max_length)
+
+        # Ensure the number of padded tokens is the same
+        padded_tokens_r = list(takewhile(lambda i: i == pad_token_id, reversed(input_r)))
+        padded_tokens_p = list(takewhile(lambda i: i == pad_token_id, reversed(input_p)))
+        self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
+
+    def assert_batch_padded_input_match(
+        self,
+        input_r: dict,
+        input_p: dict,
+        max_length: int,
+        pad_token_id: int,
+        model_main_input_name: str = "input_ids",
+    ):
+        for i_r in input_r.values():
+            self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
+                len(i_r[1]), max_length
+            )
+            self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
+                len(i_r[1]), max_length
+            )
+
+        for i_r, i_p in zip(input_r[model_main_input_name], input_p[model_main_input_name]):
+            self.assert_padded_input_match(i_r, i_p, max_length, pad_token_id)
+
+        for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
+            self.assertSequenceEqual(i_r, i_p)
+
     @staticmethod
     def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
         # Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
@@ -169,6 +201,18 @@ class TokenizerTesterMixin:
             for i in range(len(batch_encode_plus_sequences["input_ids"]))
         ]

+    def test_model_input_names_signature(self):
+        accepted_model_main_input_names = [
+            "input_ids",  # nlp models
+            "input_values",  # speech models
+        ]
+
+        tokenizers = self.get_tokenizers()
+        for tokenizer in tokenizers:
+            # first name of model_input_names has to correspond to main model input name
+            # to make sure `tokenizer.pad(...)` works correctly
+            self.assertTrue(tokenizer.model_input_names[0] in accepted_model_main_input_names)
+
     def test_rust_tokenizer_signature(self):
         if not self.test_rust_tokenizer:
             return
@@ -2419,43 +2463,20 @@ class TokenizerTesterMixin:
                 tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                 tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)

-                def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
-
-                    # Ensure we match max_length
-                    self.assertEqual(len(input_r), max_length)
-                    self.assertEqual(len(input_p), max_length)
-
-                    # Ensure the number of padded tokens is the same
-                    padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
-                    padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
-                    self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
-
-                def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
-                    for i_r in input_r.values():
-                        self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
-                            len(i_r[1]), max_length
-                        )
-                        self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
-                            len(i_r[1]), max_length
-                        )
-
-                    for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
-                        assert_padded_input_match(i_r, i_p, max_length)
-
-                    for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
-                        self.assertSequenceEqual(i_r, i_p)
+                self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
+                pad_token_id = tokenizer_p.pad_token_id

                 # Encode - Simple input
                 input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
                 input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
-                assert_padded_input_match(input_r, input_p, max_length)
+                self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
                 input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
                 input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
-                assert_padded_input_match(input_r, input_p, max_length)
+                self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)

                 input_r = tokenizer_r.encode("This is a simple input", padding="longest")
                 input_p = tokenizer_p.encode("This is a simple input", padding=True)
-                assert_padded_input_match(input_r, input_p, len(input_r))
+                self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)

                 # Encode - Pair input
                 input_r = tokenizer_r.encode(
@@ -2464,17 +2485,17 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_p.encode(
                     "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
                 )
-                assert_padded_input_match(input_r, input_p, max_length)
+                self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
                 input_r = tokenizer_r.encode(
                     "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
                 )
                 input_p = tokenizer_p.encode(
                     "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
                 )
-                assert_padded_input_match(input_r, input_p, max_length)
+                self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
                 input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
                 input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
-                assert_padded_input_match(input_r, input_p, len(input_r))
+                self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)

                 # Encode_plus - Simple input
                 input_r = tokenizer_r.encode_plus(
@@ -2483,7 +2504,7 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_p.encode_plus(
                     "This is a simple input", max_length=max_length, pad_to_max_length=True
                 )
-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+                self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
                 self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
                 input_r = tokenizer_r.encode_plus(
                     "This is a simple input", max_length=max_length, padding="max_length"
@@ -2491,12 +2512,14 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_p.encode_plus(
                     "This is a simple input", max_length=max_length, padding="max_length"
                 )
-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+                self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
                 self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

                 input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
                 input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
+                self.assert_padded_input_match(
+                    input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+                )

                 self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
@@ -2507,7 +2530,7 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_p.encode_plus(
                     "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
                 )
-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+                self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
                 self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
                 input_r = tokenizer_r.encode_plus(
                     "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
@@ -2515,11 +2538,13 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_p.encode_plus(
                     "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
                 )
-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+                self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
                 self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
                 input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
                 input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
+                self.assert_padded_input_match(
+                    input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+                )
                 self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

                 # Batch_encode_plus - Simple input
@@ -2533,7 +2558,7 @@ class TokenizerTesterMixin:
                     max_length=max_length,
                     pad_to_max_length=True,
                 )
-                assert_batch_padded_input_match(input_r, input_p, max_length)
+                self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)

                 input_r = tokenizer_r.batch_encode_plus(
                     ["This is a simple input 1", "This is a simple input 2"],
@@ -2545,7 +2570,7 @@ class TokenizerTesterMixin:
                     max_length=max_length,
                     padding="max_length",
                 )
-                assert_batch_padded_input_match(input_r, input_p, max_length)
+                self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)

                 input_r = tokenizer_r.batch_encode_plus(
                     ["This is a simple input 1", "This is a simple input 2"],
@@ -2557,7 +2582,7 @@ class TokenizerTesterMixin:
                     max_length=max_length,
                     padding=True,
                 )
-                assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+                self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

                 input_r = tokenizer_r.batch_encode_plus(
                     ["This is a simple input 1", "This is a simple input 2"], padding="longest"
@@ -2565,7 +2590,7 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_p.batch_encode_plus(
                     ["This is a simple input 1", "This is a simple input 2"], padding=True
                 )
-                assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+                self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

                 # Batch_encode_plus - Pair input
                 input_r = tokenizer_r.batch_encode_plus(
@@ -2586,7 +2611,7 @@ class TokenizerTesterMixin:
                     truncation=True,
                     padding="max_length",
                 )
-                assert_batch_padded_input_match(input_r, input_p, max_length)
+                self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)

                 input_r = tokenizer_r.batch_encode_plus(
                     [
@@ -2602,7 +2627,7 @@ class TokenizerTesterMixin:
                     ],
                     padding="longest",
                 )
-                assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+                self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

                 # Using pad on single examples after tokenization
                 input_r = tokenizer_r.encode_plus("This is a input 1")
@@ -2611,7 +2636,9 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_r.encode_plus("This is a input 1")
                 input_p = tokenizer_r.pad(input_p)

-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
+                self.assert_padded_input_match(
+                    input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+                )

                 # Using pad on single examples after tokenization
                 input_r = tokenizer_r.encode_plus("This is a input 1")
@@ -2620,7 +2647,7 @@ class TokenizerTesterMixin:
                 input_p = tokenizer_r.encode_plus("This is a input 1")
                 input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")

-                assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+                self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)

                 # Using pad after tokenization
                 input_r = tokenizer_r.batch_encode_plus(
@@ -2633,7 +2660,7 @@ class TokenizerTesterMixin:
                 )
                 input_p = tokenizer_r.pad(input_p)

-                assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+                self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

                 # Using pad after tokenization
                 input_r = tokenizer_r.batch_encode_plus(
@@ -2646,7 +2673,41 @@ class TokenizerTesterMixin:
                 )
                 input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")

-                assert_batch_padded_input_match(input_r, input_p, max_length)
+                self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+    def test_padding_different_model_input_name(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
+                pad_token_id = tokenizer_p.pad_token_id
+
+                input_r = tokenizer_r.batch_encode_plus(
+                    ["This is a input 1", "This is a much longer input which should be padded"]
+                )
+                input_p = tokenizer_r.batch_encode_plus(
+                    ["This is a input 1", "This is a much longer input which should be padded"]
+                )
+
+                # rename encoded batch to "inputs"
+                input_r["inputs"] = input_r[tokenizer_r.model_input_names[0]]
+                del input_r[tokenizer_r.model_input_names[0]]
+
+                input_p["inputs"] = input_p[tokenizer_p.model_input_names[0]]
+                del input_p[tokenizer_p.model_input_names[0]]
+
+                # Renaming `input_ids` to `inputs`
+                tokenizer_r.model_input_names = ["inputs"] + tokenizer_r.model_input_names[1:]
+                tokenizer_p.model_input_names = ["inputs"] + tokenizer_p.model_input_names[1:]
+
+                input_r = tokenizer_r.pad(input_r, padding="longest")
+                input_p = tokenizer_r.pad(input_p, padding="longest")
+
+                max_length = len(input_p["inputs"][0])
+                self.assert_batch_padded_input_match(
+                    input_r, input_p, max_length, pad_token_id, model_main_input_name="inputs"
+                )

     def test_save_pretrained(self):
         for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
@@ -174,3 +174,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             max_length=max_length,
             padding="max_length",
         )
+
+    # tokenizer has no padding token
+    def test_padding_different_model_input_name(self):
+        pass
@@ -128,3 +128,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             max_length=max_length,
             padding="max_length",
         )
+
+    # tokenizer has no padding token
+    def test_padding_different_model_input_name(self):
+        pass
@@ -107,6 +107,10 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             padding="max_length",
         )

+    # tokenizer has no padding token
+    def test_padding_different_model_input_name(self):
+        pass
+
     def test_full_tokenizer(self):
         tokenizer = ReformerTokenizer(SAMPLE_VOCAB, keep_accents=True)