Padding strategy (left and right) rather than boolean flag

LysandreJik 2019-11-21 11:30:40 -05:00
parent 9f374c8252
commit a7dafe2f41
2 changed files with 77 additions and 24 deletions
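For orientation, here is a minimal usage sketch of the new keyword; the checkpoint name and example sentence are illustrative and not part of this commit.

# Hedged usage sketch: checkpoint and text are illustrative, not from the commit.
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# The old boolean flag (pad_to_max_length=True) always padded on the right;
# the new padding_strategy argument makes the padding side explicit.
right_padded = tokenizer.encode("Hello world", max_length=10, padding_strategy='right')
left_padded = tokenizer.encode("Hello world", max_length=10, padding_strategy='left')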

View File

@@ -343,21 +343,33 @@ class CommonTestCases:
padding_size = 10
padding_idx = tokenizer.pad_token_id
# Check that it correctly pads when a maximum length is specified along with the padding flag set to True
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding strategy set to 'right'
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='right')
padded_sequence_length = len(padded_sequence)
assert sequence_length + padding_size == padded_sequence_length
assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
# Check that nothing is done when a maximum length is not specified
# LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding strategy set to 'left'
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence = tokenizer.encode(sequence, pad_to_max_length=True)
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='left')
padded_sequence_length = len(padded_sequence)
assert sequence_length == padded_sequence_length
assert encoded_sequence == padded_sequence
assert sequence_length + padding_size == padded_sequence_length
assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
# RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence_right = tokenizer.encode(sequence, padding_strategy='right')
padded_sequence_right_length = len(padded_sequence_right)
padded_sequence_left = tokenizer.encode(sequence, padding_strategy='left')
padded_sequence_left_length = len(padded_sequence_left)
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
assert sequence_length == padded_sequence_left_length
assert encoded_sequence == padded_sequence_left
def test_encode_plus_with_padding(self):
tokenizer = self.get_tokenizer()
@@ -374,7 +386,8 @@ class CommonTestCases:
special_tokens_mask = encoded_sequence['special_tokens_mask']
sequence_length = len(input_ids)
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
# Test right padding
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='right', return_special_tokens_mask=True)
padded_input_ids = padded_sequence['input_ids']
padded_token_type_ids = padded_sequence['token_type_ids']
padded_attention_mask = padded_sequence['attention_mask']
@@ -385,4 +398,18 @@ class CommonTestCases:
assert input_ids + [padding_idx] * padding_size == padded_input_ids
assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
assert attention_mask + [0] * padding_size == padded_attention_mask
assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask
# Test left padding
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='left', return_special_tokens_mask=True)
padded_input_ids = padded_sequence['input_ids']
padded_token_type_ids = padded_sequence['token_type_ids']
padded_attention_mask = padded_sequence['attention_mask']
padded_special_tokens_mask = padded_sequence['special_tokens_mask']
padded_sequence_length = len(padded_input_ids)
assert sequence_length + padding_size == padded_sequence_length
assert [padding_idx] * padding_size + input_ids == padded_input_ids
assert [token_type_padding_idx] * padding_size + token_type_ids == padded_token_type_ids
assert [0] * padding_size + attention_mask == padded_attention_mask
assert [1] * padding_size + special_tokens_mask == padded_special_tokens_mask
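To make the assertions above concrete, the sketch below shows what the two strategies produce for a toy encoding; the token ids, pad id, and lengths are invented for illustration.

# Toy illustration of the assertions above; the ids are invented.
input_ids = [101, 7592, 2088, 102]                   # sequence_length == 4
pad_id, pad_size = 0, 2                              # target max_length == 6
# padding_strategy='right': ids first, padding appended
right_ids = input_ids + [pad_id] * pad_size          # [101, 7592, 2088, 102, 0, 0]
right_mask = [1] * len(input_ids) + [0] * pad_size   # [1, 1, 1, 1, 0, 0]
# padding_strategy='left': padding prepended, ids last
left_ids = [pad_id] * pad_size + input_ids           # [0, 0, 101, 7592, 2088, 102]
left_mask = [0] * pad_size + [1] * len(input_ids)    # [0, 0, 1, 1, 1, 1]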

View File

@@ -702,7 +702,7 @@ class PreTrainedTokenizer(object):
max_length=None,
stride=0,
truncation_strategy='longest_first',
pad_to_max_length=False,
padding_strategy=None,
return_tensors=None,
**kwargs):
"""
@@ -729,8 +729,12 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
pad_to_max_length: if set to `True`, the returned sequences will be padded according to the model's
padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
padding index, up to the specified max length. If no max length is specified, no padding is done.
The available strategies are:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
Defaults to None: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
@@ -741,7 +745,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=add_special_tokens,
stride=stride,
truncation_strategy=truncation_strategy,
pad_to_max_length=pad_to_max_length,
padding_strategy=padding_strategy,
return_tensors=return_tensors,
**kwargs)
@@ -754,7 +758,7 @@ class PreTrainedTokenizer(object):
max_length=None,
stride=0,
truncation_strategy='longest_first',
pad_to_max_length=False,
padding_strategy=None,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
@@ -784,8 +788,12 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
pad_to_max_length: if set to `True`, the returned sequences will be padded according to the model's
padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
padding index, up to the specified max length. If no max length is specified, no padding is done.
The available strategies are:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
Defaults to None: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
@@ -833,7 +841,7 @@ class PreTrainedTokenizer(object):
return self.prepare_for_model(first_ids,
pair_ids=second_ids,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
padding_strategy=padding_strategy,
add_special_tokens=add_special_tokens,
stride=stride,
truncation_strategy=truncation_strategy,
@@ -845,7 +853,7 @@ class PreTrainedTokenizer(object):
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
truncation_strategy='longest_first',
pad_to_max_length=False,
padding_strategy=None,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
@@ -873,8 +881,12 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
pad_to_max_length: if set to `True`, the returned sequences will be padded according to the model's
padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
padding index, up to the specified max length. If no max length is specified, no padding is done.
The available strategies are:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
Defaults to None: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
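As a hedged sketch of calling prepare_for_model directly with the new argument (reusing the illustrative tokenizer from the first sketch; the text and max length are again made up):

# Illustrative only: prepare_for_model expects ids that have already been converted.
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))
encoded = tokenizer.prepare_for_model(ids, max_length=16, padding_strategy='left',
                                      return_special_tokens_mask=True)
# Assuming the sequence plus special tokens is shorter than 16, pad token ids are
# prepended to encoded["input_ids"] and zeros to encoded["attention_mask"],
# matching the 'left' branch implemented below.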
@@ -943,16 +955,30 @@ class PreTrainedTokenizer(object):
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
if pad_to_max_length and max_length and len(encoded_inputs["input_ids"]) < max_length:
if padding_strategy is not None and max_length and len(encoded_inputs["input_ids"]) < max_length:
difference = max_length - len(encoded_inputs["input_ids"])
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] += [self.pad_token_type_id] * difference
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] += [1] * difference
encoded_inputs["input_ids"] += [self.pad_token_id] * difference
if padding_strategy == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
elif padding_strategy == 'left':
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
if return_token_type_ids:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs["token_type_ids"]
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
else:
raise ValueError("Invalid padding strategy:" + str(padding_strategy))
elif return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])