Replaced torch.load with pickle.load for loading the pretrained vocab of the TransformerXL tokenizer (#6935)

* Replaced torch.load with pickle.load for loading the pretrained vocab of TransformerXL

* Replaced torch.save with pickle.dump when saving the vocabulary

* Updated Transformer-XL

* Uploaded the new vocabulary file to S3 for backward compatibility

* Fixed tests

* Style fixes

* Addressed review comments

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Piero Molino committed 2020-10-08 01:16:10 -07:00 (via GitHub)
commit 4d04120c6d (parent aba4e22944)
2 changed files with 59 additions and 14 deletions
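
The change itself is simple: the vocabulary dict is now persisted with the standard pickle module instead of torch's serializer, so it can be read back without a PyTorch installation. A minimal before/after sketch with an illustrative vocabulary dict (the sym2idx/idx2sym keys mirror the tokenizer's attributes, but the values here are made up):

import pickle

# Before this commit (required torch on both ends):
#   torch.save(vocab_dict, "vocab.bin")
#   vocab_dict = torch.load("vocab.bin")

# After (framework-agnostic):
vocab_dict = {"sym2idx": {"<unk>": 0}, "idx2sym": ["<unk>"]}

with open("vocab.pkl", "wb") as f:
    pickle.dump(vocab_dict, f)

with open("vocab.pkl", "rb") as f:
    assert pickle.load(f) == vocab_dict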

src/transformers/file_utils.py

@@ -203,6 +203,19 @@ def is_faiss_available():
     return _faiss_available
 
 
+def torch_only_method(fn):
+    def wrapper(*args, **kwargs):
+        if not _torch_available:
+            raise ImportError(
+                "You need to install pytorch to use this method or class, "
+                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
+            )
+        else:
+            return fn(*args, **kwargs)
+
+    return wrapper
+
+
 def is_sklearn_available():
     return _has_sklearn
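
torch_only_method wraps any function that hard-requires PyTorch and turns a missing install into a descriptive ImportError at call time, rather than a failure deep inside the call. A hypothetical usage sketch (the tensorize function is illustrative, not library code):

from transformers.file_utils import torch_only_method  # added in this commit


@torch_only_method
def tensorize(tokenizer, text):
    # Safe to import here: the decorator has verified torch is available.
    import torch

    return torch.LongTensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))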

src/transformers/tokenization_transfo_xl.py

@@ -36,7 +36,7 @@ from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalize
 from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
 from tokenizers.processors import BertProcessing
 
-from .file_utils import cached_path, is_torch_available
+from .file_utils import cached_path, is_torch_available, torch_only_method
 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_utils_fast import PreTrainedTokenizerFast
 from .utils import logging
@@ -48,12 +48,16 @@ if is_torch_available():
 
 logger = logging.get_logger(__name__)
 
-VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {
+    "pretrained_vocab_file": "vocab.pkl",
+    "pretrained_vocab_file_torch": "vocab.bin",
+    "vocab_file": "vocab.txt",
+}
 
 VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}
 
 PRETRAINED_VOCAB_FILES_MAP = {
     "pretrained_vocab_file": {
-        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
+        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl",
     }
 }
@@ -139,8 +143,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             File containing the vocabulary (from the original implementation).
         pretrained_vocab_file (:obj:`str`, `optional`):
             File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
-        never_split (xxx, `optional`):
-            Fill me with intesting stuff.
+        never_split (:obj:`List[str]`, `optional`):
+            List of tokens that should never be split. If no list is specified, will simply use the existing
+            special tokens.
         unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
             token instead.
@@ -165,7 +170,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         lower_case=False,
         delimiter=None,
         vocab_file=None,
-        pretrained_vocab_file=None,
+        pretrained_vocab_file: str = None,
         never_split=None,
         unk_token="<unk>",
         eos_token="<eos>",
@@ -197,23 +202,40 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         self.moses_tokenizer = sm.MosesTokenizer(language)
         self.moses_detokenizer = sm.MosesDetokenizer(language)
 
+        # This try... catch... is not beautiful but honestly this tokenizer was not made to be used
+        # in a library like ours, at all.
         try:
+            vocab_dict = None
             if pretrained_vocab_file is not None:
-                # Hack because, honestly this tokenizer was not made to be used
-                # in a library like ours, at all.
-                vocab_dict = torch.load(pretrained_vocab_file)
+                # Priority on pickle files (support PyTorch and TF)
+                with open(pretrained_vocab_file, "rb") as f:
+                    vocab_dict = pickle.load(f)
+
+                # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer
+                # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
+                # We therefore load it with torch, if it's available.
+                if type(vocab_dict) == int:
+                    if not is_torch_available():
+                        raise ImportError(
+                            "Not trying to load dict with PyTorch as you need to install pytorch to load "
+                            "from a PyTorch pretrained vocabulary, "
+                            "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
+                        )
+                    vocab_dict = torch.load(pretrained_vocab_file)
+
+            if vocab_dict is not None:
                 for key, value in vocab_dict.items():
                     if key not in self.__dict__:
                         self.__dict__[key] = value
-
-            if vocab_file is not None:
+            elif vocab_file is not None:
                 self.build_vocab()
 
-        except Exception:
+        except Exception as e:
             raise ValueError(
                 "Unable to parse file {}. Unknown format. "
                 "If you tried to load a model saved through TransfoXLTokenizerFast,"
                 "please note they are not compatible.".format(pretrained_vocab_file)
-            )
+            ) from e
+
+        if vocab_file is not None:
+            self.build_vocab()
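
The type(vocab_dict) == int check works because PyTorch's legacy (pre-zipfile) serialization format begins with a pickled magic-number integer: running plain pickle.load on such a file succeeds but returns that integer instead of the saved dict. A standalone sketch of the fallback pattern (the helper name is illustrative, not library code):

import pickle


def load_vocab_dict(path):
    # Try the new pickle format first; it needs no ML framework.
    with open(path, "rb") as f:
        obj = pickle.load(f)

    # A legacy torch file un-pickles to its magic-number int, which tells
    # us to retry with torch.load (assumes torch is installed here).
    if isinstance(obj, int):
        import torch

        obj = torch.load(path)
    return obj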
@@ -286,7 +308,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
         else:
             vocab_file = vocab_path
-        torch.save(self.__dict__, vocab_file)
+        with open(vocab_file, "wb") as f:
+            pickle.dump(self.__dict__, f)
 
         return (vocab_file,)
 
     def build_vocab(self):
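
With save_vocabulary() now writing vocab.pkl via pickle.dump, a saved tokenizer round-trips through the new format transparently. A hedged usage sketch (local path and probed token are illustrative):

from transformers import TransfoXLTokenizer

tok = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
tok.save_pretrained("./ckpt")  # now writes vocab.pkl via pickle.dump
tok2 = TransfoXLTokenizer.from_pretrained("./ckpt")  # pickle-first load path
assert tok2.convert_tokens_to_ids(["<eos>"]) == tok.convert_tokens_to_ids(["<eos>"])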
@@ -309,6 +332,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
 
         logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))
 
+    @torch_only_method
     def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
         if verbose:
             logger.info("encoding file {} ...".format(path))
@@ -326,6 +350,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
 
         return encoded
 
+    @torch_only_method
     def encode_sents(self, sents, ordered=False, verbose=False):
         if verbose:
             logger.info("encoding {} sents ...".format(len(sents)))
@@ -436,6 +461,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         out_string = self.moses_detokenizer.detokenize(tokens)
         return detokenize_numbers(out_string).strip()
 
+    @torch_only_method
     def convert_to_tensor(self, symbols):
         return torch.LongTensor(self.convert_tokens_to_ids(symbols))
@@ -706,6 +732,7 @@ class LMShuffledIterator(object):
             for idx in epoch_indices:
                 yield self.data[idx]
 
+    @torch_only_method
     def stream_iterator(self, sent_stream):
         # streams for each data in the batch
         streams = [None] * self.bsz
@@ -795,6 +822,7 @@ class LMMultiFileIterator(LMShuffledIterator):
 
 class TransfoXLCorpus(object):
     @classmethod
+    @torch_only_method
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a pre-processed corpus.
@@ -892,10 +920,14 @@ class TransfoXLCorpus(object):
                 data_iter = LMOrderedIterator(data, *args, **kwargs)
             elif self.dataset == "lm1b":
                 data_iter = LMShuffledIterator(data, *args, **kwargs)
+        else:
+            data_iter = None
+            raise ValueError(f"Split not recognized: {split}")
 
         return data_iter
 
 
+@torch_only_method
 def get_lm_corpus(datadir, dataset):
     fn = os.path.join(datadir, "cache.pt")
+    fn_pickle = os.path.join(datadir, "cache.pkl")
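
The (truncated) get_lm_corpus change follows the same pattern at the corpus level: prefer a cache.pkl written with pickle, and keep the old cache.pt torch cache as a fallback. A sketch of that pattern under those assumptions (the helper name is illustrative):

import os
import pickle


def load_corpus_cache(datadir):
    # Illustrative cache lookup: pickle cache first, legacy torch cache
    # as a fallback (mirrors the fn_pickle / fn pair above).
    fn_pickle = os.path.join(datadir, "cache.pkl")
    fn_torch = os.path.join(datadir, "cache.pt")

    if os.path.exists(fn_pickle):
        with open(fn_pickle, "rb") as f:
            return pickle.load(f)
    if os.path.exists(fn_torch):
        import torch

        return torch.load(fn_torch)
    return None  # no cache yet; caller builds the corpus from scratch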