clean up __init__

This commit is contained in:
thomwolf 2019-09-25 21:04:52 +02:00
parent 3b7fb48c3b
commit 8a618e0af5
2 changed files with 85 additions and 62 deletions

View File

@ -16,7 +16,21 @@ import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Tokenizer
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
is_tf_available, is_torch_available)
from .data import (is_sklearn_available,
InputExample, InputFeatures, DataProcessor,
glue_output_modes, glue_convert_examples_to_features,
glue_processors, glue_tasks_num_labels)
if is_sklearn_available():
from .data import glue_compute_metrics
# Tokenizers
from .tokenization_utils import (PreTrainedTokenizer)
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
@ -41,13 +55,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
# Modeling
try:
import torch
_torch_available = True # pylint: disable=invalid-name
except ImportError:
_torch_available = False # pylint: disable=invalid-name
if _torch_available:
if is_torch_available():
logger.info("PyTorch version {} available.".format(torch.__version__))
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
@ -87,14 +95,7 @@ if _torch_available:
# TensorFlow
try:
import tensorflow as tf
assert int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
if _tf_available:
if is_tf_available():
logger.info("TensorFlow version {} available.".format(tf.__version__))
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
@ -151,7 +152,8 @@ if _tf_available:
load_distilbert_pt_weights_in_tf2,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
if _tf_available and _torch_available:
# TF 2.0 <=> PyTorch conversion utilities
if is_tf_available() and is_torch_available():
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
load_pytorch_checkpoint_in_tf2_model,
load_pytorch_weights_in_tf2_model,
@ -159,17 +161,3 @@ if _tf_available and _torch_available:
load_tf2_checkpoint_in_pytorch_model,
load_tf2_weights_in_pytorch_model,
load_tf2_model_in_pytorch_model)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
is_tf_available, is_torch_available)
from .data import (is_sklearn_available,
InputExample, InputFeatures, DataProcessor,
glue_output_modes, glue_convert_examples_to_features,
glue_processors, glue_tasks_num_labels)
if is_sklearn_available():
from .data import glue_compute_metrics

View File

@ -23,7 +23,7 @@ import six
import copy
from io import open
from .file_utils import cached_path, is_tf_available
from .file_utils import cached_path, is_tf_available, is_torch_available
if is_tf_available():
import tensorflow as tf
@ -690,39 +690,20 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
def encode(self,
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
return_tensors=None,
**kwargs):
"""
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
Args:
text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
method)
text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
string using the `tokenize` method) or a list of integers (tokenized string ids using the
`convert_tokens_to_ids` method)
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
**kwargs: passed to the `self.tokenize()` method
"""
encoded_inputs = self.encode_plus(text, text_pair=text_pair, add_special_tokens=add_special_tokens, **kwargs)
return encoded_inputs["input_ids"]
def encode_plus(self,
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
**kwargs):
"""
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
Args:
text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
@ -738,6 +719,51 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
"""
encoded_inputs = self.encode_plus(text,
text_pair=text_pair,
max_length=max_length,
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence,
return_tensors=return_tensors,
**kwargs)
return encoded_inputs["input_ids"]
def encode_plus(self,
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
return_tensors=None,
**kwargs):
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
Args:
text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
method)
text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
string using the `tokenize` method) or a list of integers (tokenized string ids using the
`convert_tokens_to_ids` method)
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
"""
@ -759,10 +785,12 @@ class PreTrainedTokenizer(object):
max_length=max_length,
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence)
truncate_first_sequence=truncate_first_sequence,
return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, truncate_first_sequence=True):
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
truncate_first_sequence=True, return_tensors=None):
"""
Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
It adds special tokens, truncates
@ -782,6 +810,8 @@ class PreTrainedTokenizer(object):
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is greater
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
Return:
a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
@ -816,6 +846,11 @@ class PreTrainedTokenizer(object):
sequence = ids + pair_ids if pair else ids
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
if return_tensors == 'tf' and is_tf_available():
sequence = tf.constant(sequence)
token_type_ids = tf.constant(token_type_ids)
elif return_tensors = 'pt' and is
encoded_inputs["input_ids"] = sequence
encoded_inputs["token_type_ids"] = token_type_ids