# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .tokenization_bert import BertTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlnet import XLNetTokenizer
from .tokenization_xlm import XLMTokenizer

logger = logging.getLogger(__name__)


class AutoTokenizer(object):
    r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class
        that will be instantiated as one of the tokenizer classes of the library
        when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct tokenizer class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The tokenizer class to instantiate is selected based on the first matching pattern
        in the `pretrained_model_name_or_path` string (checked in the following order):
            - contains `bert`: BertTokenizer (Bert model)
            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - contains `xlnet`: XLNetTokenizer (XLNet model)
            - contains `xlm`: XLMTokenizer (XLM model)

        This class cannot be instantiated using `__init__()` (it raises an error).
    """

    def __init__(self):
        raise EnvironmentError("AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r""" Instantiate one of the tokenizer classes of the library
        from a pre-trained model vocabulary.

        The tokenizer class to instantiate is selected based on the first matching pattern
        in the `pretrained_model_name_or_path` string (checked in the following order):
            - contains `bert`: BertTokenizer (Bert model)
            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - contains `xlnet`: XLNetTokenizer (XLNet model)
            - contains `xlm`: XLMTokenizer (XLM model)

        Params:
            **pretrained_model_name_or_path**: either:
                - a string with the `shortcut name` of a pre-trained tokenizer vocabulary to load from cache
                  or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
                - a path to a `directory` containing the vocabulary files saved
                  using the `save_pretrained(save_directory)` method.
                - a path or url to a saved vocabulary `file`.
            **cache_dir**: (`optional`) string:
                Path to a directory in which a downloaded pre-trained tokenizer
                vocabulary should be cached if the standard cache should not be used.

        Examples::

            >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')    # Download vocabulary from S3 and cache.
            >>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')  # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`

        """
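        # Dispatch by simple substring matching on the name/path, in the order documented above.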
        if 'bert' in pretrained_model_name_or_path:
            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'openai-gpt' in pretrained_model_name_or_path:
            return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'gpt2' in pretrained_model_name_or_path:
            return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'transfo-xl' in pretrained_model_name_or_path:
            return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        raise ValueError("Unrecognized model identifier in {}. Should contain one of "
                         "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
                         "'xlm'".format(pretrained_model_name_or_path))
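

# Minimal usage sketch of the dispatch implemented above; assumes the
# 'bert-base-uncased' vocabulary can be downloaded or is already cached, and
# that the module is run as part of its package (it uses relative imports).
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # 'bert' in the name -> BertTokenizer
    tokens = tokenizer.tokenize("Hello, world!")            # WordPiece tokenization
    token_ids = tokenizer.convert_tokens_to_ids(tokens)     # map tokens to vocabulary ids
    print(tokens, token_ids)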