mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix encode plus
This commit is contained in:
parent
030faccb8d
commit
3d57c51111
@ -916,7 +916,7 @@ class PreTrainedTokenizer(object):
|
||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||
or PyTorch torch.Tensor instead of a list of python integers.
|
||||
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
|
||||
return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
|
||||
return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
|
||||
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
|
||||
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
|
||||
|
||||
@ -961,24 +961,13 @@ class PreTrainedTokenizer(object):
|
||||
if add_special_tokens:
|
||||
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
||||
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
||||
special_tokens_mask = self.get_special_tokens_mask(ids, pair_ids)
|
||||
else:
|
||||
sequence = ids + pair_ids if pair else ids
|
||||
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
|
||||
special_tokens_mask = [0] * (len(ids) + (len(pair_ids) if pair else 0))
|
||||
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||
|
||||
# Prepare inputs as tensors if asked
|
||||
if return_tensors == 'tf' and is_tf_available():
|
||||
sequence = tf.constant([sequence])
|
||||
token_type_ids = tf.constant([token_type_ids])
|
||||
elif return_tensors == 'pt' and is_torch_available():
|
||||
sequence = torch.tensor([sequence])
|
||||
token_type_ids = torch.tensor([token_type_ids])
|
||||
elif return_tensors is not None:
|
||||
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
|
||||
|
||||
encoded_inputs["input_ids"] = sequence
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = token_type_ids
|
||||
@ -1015,10 +1004,9 @@ class PreTrainedTokenizer(object):
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
|
||||
|
||||
elif self.padding_side == 'left':
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
|
||||
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs["token_type_ids"]
|
||||
if return_special_tokens_mask:
|
||||
@ -1030,7 +1018,26 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
elif return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
|
||||
|
||||
|
||||
# Prepare inputs as tensors if asked
|
||||
if return_tensors == 'tf' and is_tf_available():
|
||||
encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
|
||||
encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]])
|
||||
|
||||
elif return_tensors == 'pt' and is_torch_available():
|
||||
encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]])
|
||||
encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = torch.tensor([encoded_inputs["attention_mask"]])
|
||||
elif return_tensors is not None:
|
||||
logger.warning(
|
||||
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
|
||||
return_tensors))
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
|
||||
|
Loading…
Reference in New Issue
Block a user