mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
fixed loading pre-trained tokenizer from directory
This commit is contained in:
parent
532a81d3d6
commit
d6f06c03f4
@ -478,7 +478,7 @@ class PreTrainedBertModel(nn.Module):
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name))
|
||||
archive_file))
|
||||
return None
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info("loading archive file {}".format(archive_file))
|
||||
|
@ -39,6 +39,7 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
||||
}
|
||||
VOCAB_NAME = 'vocab.txt'
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
@ -100,7 +101,7 @@ class BertTokenizer(object):
|
||||
return tokens
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
|
||||
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
@ -109,16 +110,11 @@ class BertTokenizer(object):
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
|
||||
else:
|
||||
vocab_file = pretrained_model_name
|
||||
if os.path.isdir(vocab_file):
|
||||
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file)
|
||||
if resolved_vocab_file == vocab_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, do_lower_case)
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
@ -126,8 +122,15 @@ class BertTokenizer(object):
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name))
|
||||
tokenizer = None
|
||||
vocab_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user