Mirror of https://github.com/huggingface/transformers.git
Replaced torch.load with pickle.load for loading the pretrained vocab of the TransformerXL tokenizer (#6935)
* Replaced torch.load for loading the pretrained vocab of TransformerXL to pickle.load
* Replaced torch.save with pickle.dump when saving the vocabulary
* updating transformer-xl
* uploaded on S3 - compatibility
* fix tests
* style
* Address review comments

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
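In practice, the change means the slow TransfoXL tokenizer no longer needs PyTorch just to read its pretrained vocabulary: the hosted vocab is now a plain pickle (vocab.pkl), and torch.load remains only as a fallback for vocabularies saved in the old vocab.bin format. A minimal usage sketch (not part of the commit; assumes the transformers package of this era):

```python
from transformers import TransfoXLTokenizer

# Fetches the pretrained vocab as vocab.pkl and loads it with pickle.load, so this
# works in a TensorFlow-only environment; torch is only needed if an old
# torch-serialized vocab.bin has to be read instead.
tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
print(tokenizer.tokenize("The quick brown fox"))
```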
Parent: aba4e22944
Commit: 4d04120c6d
@@ -203,6 +203,19 @@ def is_faiss_available():
     return _faiss_available


+def torch_only_method(fn):
+    def wrapper(*args, **kwargs):
+        if not _torch_available:
+            raise ImportError(
+                "You need to install pytorch to use this method or class, "
+                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
+            )
+        else:
+            return fn(*args, **kwargs)
+
+    return wrapper
+
+
 def is_sklearn_available():
     return _has_sklearn

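The hunk above adds a `torch_only_method` decorator to `file_utils`; the tokenizer module imports it (next hunk) and applies it to the methods that still build torch tensors, so calling one of them without PyTorch installed fails early with a clear ImportError instead of a NameError on `torch`. A hypothetical usage sketch (`build_tensor` is a made-up example, not from the commit):

```python
from transformers.file_utils import torch_only_method


@torch_only_method
def build_tensor(ids):
    import torch

    return torch.LongTensor(ids)

# With PyTorch installed the wrapper is transparent; without it, build_tensor(...)
# raises the ImportError from the decorator before the function body ever runs.
```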
@@ -36,7 +36,7 @@ from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalize
 from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
 from tokenizers.processors import BertProcessing

-from .file_utils import cached_path, is_torch_available
+from .file_utils import cached_path, is_torch_available, torch_only_method
 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_utils_fast import PreTrainedTokenizerFast
 from .utils import logging
@@ -48,12 +48,16 @@ if is_torch_available():

 logger = logging.get_logger(__name__)

-VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {
+    "pretrained_vocab_file": "vocab.pkl",
+    "pretrained_vocab_file_torch": "vocab.bin",
+    "vocab_file": "vocab.txt",
+}
 VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}

 PRETRAINED_VOCAB_FILES_MAP = {
     "pretrained_vocab_file": {
-        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
+        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl",
     }
 }

@@ -139,8 +143,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             File containing the vocabulary (from the original implementation).
         pretrained_vocab_file (:obj:`str`, `optional`):
             File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
-        never_split (xxx, `optional`):
-            Fill me with intesting stuff.
+        never_split (:obj:`List[str]`, `optional`):
+            List of tokens that should never be split. If no list is specified, will simply use the existing
+            special tokens.
         unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
             token instead.
@@ -165,7 +170,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         lower_case=False,
         delimiter=None,
         vocab_file=None,
-        pretrained_vocab_file=None,
+        pretrained_vocab_file: str = None,
         never_split=None,
         unk_token="<unk>",
         eos_token="<eos>",
@@ -197,23 +202,40 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         self.moses_tokenizer = sm.MosesTokenizer(language)
         self.moses_detokenizer = sm.MosesDetokenizer(language)

+        # This try... catch... is not beautiful but honestly this tokenizer was not made to be used
+        # in a library like ours, at all.
         try:
+            vocab_dict = None
             if pretrained_vocab_file is not None:
-                # Hack because, honestly this tokenizer was not made to be used
-                # in a library like ours, at all.
-                vocab_dict = torch.load(pretrained_vocab_file)
+                # Priority on pickle files (support PyTorch and TF)
+                with open(pretrained_vocab_file, "rb") as f:
+                    vocab_dict = pickle.load(f)
+
+                # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer
+                # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
+                # We therefore load it with torch, if it's available.
+                if type(vocab_dict) == int:
+                    if not is_torch_available():
+                        raise ImportError(
+                            "Not trying to load dict with PyTorch as you need to install pytorch to load "
+                            "from a PyTorch pretrained vocabulary, "
+                            "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
+                        )
+                    vocab_dict = torch.load(pretrained_vocab_file)
+
+            if vocab_dict is not None:
                 for key, value in vocab_dict.items():
                     if key not in self.__dict__:
                         self.__dict__[key] = value
-
-            if vocab_file is not None:
+            elif vocab_file is not None:
                 self.build_vocab()
-        except Exception:
+
+        except Exception as e:
             raise ValueError(
                 "Unable to parse file {}. Unknown format. "
                 "If you tried to load a model saved through TransfoXLTokenizerFast,"
                 "please note they are not compatible.".format(pretrained_vocab_file)
-            )
+            ) from e

         if vocab_file is not None:
             self.build_vocab()
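The `type(vocab_dict) == int` check relies on a quirk of torch's legacy (non-zip) serialization format, which is how the old vocab.bin files were written: such a file starts with a pickled magic number, so a plain pickle.load returns an integer rather than the expected dict, and that is the signal to retry with torch.load. A standalone sketch of the same fallback (`load_vocab_dict` is an illustrative name, not one used in the commit):

```python
import pickle


def load_vocab_dict(path):
    with open(path, "rb") as f:
        vocab_dict = pickle.load(f)  # vocab.pkl: this already is the vocab dict

    # A legacy torch.save file begins with a pickled magic number, so pickle.load
    # hands back an int instead of a dict; fall back to torch.load in that case.
    if isinstance(vocab_dict, int):
        import torch

        vocab_dict = torch.load(path)
    return vocab_dict
```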
@@ -286,7 +308,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
         else:
             vocab_file = vocab_path
-        torch.save(self.__dict__, vocab_file)
+        with open(vocab_file, "wb") as f:
+            pickle.dump(self.__dict__, f)
         return (vocab_file,)

     def build_vocab(self):
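With save_vocabulary switching from torch.save to pickle.dump, a vocabulary written by the new code round-trips through the new loading path without any torch dependency. A rough sketch of that round trip (the dict fields are representative, not the tokenizer's full state):

```python
import pickle

state = {"idx2sym": ["<eos>", "the"], "sym2idx": {"<eos>": 0, "the": 1}}

with open("vocab.pkl", "wb") as f:
    pickle.dump(state, f)  # what save_vocabulary now does with self.__dict__

with open("vocab.pkl", "rb") as f:
    assert pickle.load(f) == state  # loads back as a dict, no torch required
```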
@@ -309,6 +332,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):

         logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))

+    @torch_only_method
     def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
         if verbose:
             logger.info("encoding file {} ...".format(path))
@@ -326,6 +350,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):

         return encoded

+    @torch_only_method
     def encode_sents(self, sents, ordered=False, verbose=False):
         if verbose:
             logger.info("encoding {} sents ...".format(len(sents)))
@@ -436,6 +461,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         out_string = self.moses_detokenizer.detokenize(tokens)
         return detokenize_numbers(out_string).strip()

+    @torch_only_method
     def convert_to_tensor(self, symbols):
         return torch.LongTensor(self.convert_tokens_to_ids(symbols))

@@ -706,6 +732,7 @@ class LMShuffledIterator(object):
         for idx in epoch_indices:
             yield self.data[idx]

+    @torch_only_method
     def stream_iterator(self, sent_stream):
         # streams for each data in the batch
         streams = [None] * self.bsz
@@ -795,6 +822,7 @@ class LMMultiFileIterator(LMShuffledIterator):

 class TransfoXLCorpus(object):
     @classmethod
+    @torch_only_method
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a pre-processed corpus.
@@ -892,10 +920,14 @@ class TransfoXLCorpus(object):
                 data_iter = LMOrderedIterator(data, *args, **kwargs)
             elif self.dataset == "lm1b":
                 data_iter = LMShuffledIterator(data, *args, **kwargs)
+        else:
+            data_iter = None
+            raise ValueError(f"Split not recognized: {split}")

         return data_iter


+@torch_only_method
 def get_lm_corpus(datadir, dataset):
     fn = os.path.join(datadir, "cache.pt")
     fn_pickle = os.path.join(datadir, "cache.pkl")
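Besides decorating get_lm_corpus with @torch_only_method, this last hunk makes TransfoXLCorpus.get_iterator fail loudly on an unknown split instead of silently returning None. A hedged sketch of the new behaviour (the split name is deliberately wrong; downloading the pre-processed corpus still needs torch and a network connection):

```python
from transformers.tokenization_transfo_xl import TransfoXLCorpus

# from_pretrained is now guarded by @torch_only_method, so this requires torch.
corpus = TransfoXLCorpus.from_pretrained("transfo-xl-wt103")

# "validation" is not a recognized split, so with the added else branch this now
# raises ValueError("Split not recognized: validation") instead of returning None.
corpus.get_iterator("validation", bsz=16, bptt=128)
```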