mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Added RoBERTa to AutoModel/AutoConfig
This commit is contained in:
parent
fe02e45e48
commit
e24e19ce3b
@ -29,6 +29,7 @@ from .modeling_gpt2 import GPT2Config, GPT2Model
|
|||||||
from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
|
from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
|
||||||
from .modeling_xlnet import XLNetConfig, XLNetModel
|
from .modeling_xlnet import XLNetConfig, XLNetModel
|
||||||
from .modeling_xlm import XLMConfig, XLMModel
|
from .modeling_xlm import XLMConfig, XLMModel
|
||||||
|
from .modeling_roberta import RobertaConfig, RobertaModel
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
|
||||||
@ -51,6 +52,7 @@ class AutoConfig(object):
|
|||||||
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetConfig (XLNet model)
|
- contains `xlnet`: XLNetConfig (XLNet model)
|
||||||
- contains `xlm`: XLMConfig (XLM model)
|
- contains `xlm`: XLMConfig (XLM model)
|
||||||
|
- contains `roberta`: RobertaConfig (RoBERTa model)
|
||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throw an error).
|
This class cannot be instantiated using `__init__()` (throw an error).
|
||||||
"""
|
"""
|
||||||
@ -71,6 +73,7 @@ class AutoConfig(object):
|
|||||||
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetConfig (XLNet model)
|
- contains `xlnet`: XLNetConfig (XLNet model)
|
||||||
- contains `xlm`: XLMConfig (XLM model)
|
- contains `xlm`: XLMConfig (XLM model)
|
||||||
|
- contains `roberta`: RobertaConfig (RoBERTa model)
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
**pretrained_model_name_or_path**: either:
|
**pretrained_model_name_or_path**: either:
|
||||||
@ -119,6 +122,8 @@ class AutoConfig(object):
|
|||||||
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'xlm' in pretrained_model_name_or_path:
|
elif 'xlm' in pretrained_model_name_or_path:
|
||||||
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
|
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
@ -137,12 +142,13 @@ class AutoModel(object):
|
|||||||
|
|
||||||
The base model class to instantiate is selected as the first pattern matching
|
The base model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `bert`: BertConfig (Bert model)
|
- contains `bert`: BertModel (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||||
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetConfig (XLNet model)
|
- contains `xlnet`: XLNetModel (XLNet model)
|
||||||
- contains `xlm`: XLMConfig (XLM model)
|
- contains `xlm`: XLMModel (XLM model)
|
||||||
|
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throw an error).
|
This class cannot be instantiated using `__init__()` (throw an error).
|
||||||
"""
|
"""
|
||||||
@ -157,12 +163,13 @@ class AutoModel(object):
|
|||||||
|
|
||||||
The base model class to instantiate is selected as the first pattern matching
|
The base model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `bert`: BertConfig (Bert model)
|
- contains `bert`: BertModel (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||||
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetConfig (XLNet model)
|
- contains `xlnet`: XLNetModel (XLNet model)
|
||||||
- contains `xlm`: XLMConfig (XLM model)
|
- contains `xlm`: XLMModel (XLM model)
|
||||||
|
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||||
|
|
||||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
To train the model, you should first set it back in training mode with `model.train()`
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
@ -230,6 +237,8 @@ class AutoModel(object):
|
|||||||
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlm' in pretrained_model_name_or_path:
|
elif 'xlm' in pretrained_model_name_or_path:
|
||||||
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
|
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
|
Loading…
Reference in New Issue
Block a user