diff --git a/src/transformers/commands/add_new_model.py b/src/transformers/commands/add_new_model.py index 6cab10de0e5..23270fd6b22 100644 --- a/src/transformers/commands/add_new_model.py +++ b/src/transformers/commands/add_new_model.py @@ -47,7 +47,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): path_to_transformer_root = ( Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent ) - path_to_cookiecutter = path_to_transformer_root / "templates" / "cookiecutter" + path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model" # Execute cookiecutter if not self._testing: @@ -75,9 +75,16 @@ class AddNewModelCommand(BaseTransformersCLICommand): output_pytorch = "PyTorch" in pytorch_or_tensorflow output_tensorflow = "TensorFlow" in pytorch_or_tensorflow + model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}" + os.makedirs(model_dir, exist_ok=True) + + shutil.move( + f"{directory}/__init__.py", + f"{model_dir}/__init__.py", + ) shutil.move( f"{directory}/configuration_{lowercase_model_name}.py", - f"{path_to_transformer_root}/src/transformers/configuration_{lowercase_model_name}.py", + f"{model_dir}/configuration_{lowercase_model_name}.py", ) def remove_copy_lines(path): @@ -94,7 +101,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): shutil.move( f"{directory}/modeling_{lowercase_model_name}.py", - f"{path_to_transformer_root}/src/transformers/modeling_{lowercase_model_name}.py", + f"{model_dir}/modeling_{lowercase_model_name}.py", ) shutil.move( @@ -111,7 +118,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): shutil.move( f"{directory}/modeling_tf_{lowercase_model_name}.py", - f"{path_to_transformer_root}/src/transformers/modeling_tf_{lowercase_model_name}.py", + f"{model_dir}/modeling_tf_{lowercase_model_name}.py", ) shutil.move( @@ -129,7 +136,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): shutil.move( f"{directory}/tokenization_{lowercase_model_name}.py", - f"{path_to_transformer_root}/src/transformers/tokenization_{lowercase_model_name}.py", + f"{model_dir}/tokenization_{lowercase_model_name}.py", ) from os import fdopen, remove diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b056dc2790d..4a0a254d5c5 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -21,6 +21,8 @@ from collections import OrderedDict from ...configuration_utils import PretrainedConfig from ...file_utils import add_start_docstrings from ...utils import logging + +# Add modeling imports here from ..albert.modeling_albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, @@ -228,8 +230,6 @@ from .configuration_auto import ( ) -# Add modeling imports here - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 291b3243079..b43f15947a0 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -21,6 +21,8 @@ from collections import OrderedDict from ...configuration_utils import PretrainedConfig from ...file_utils import add_start_docstrings from ...utils import logging + +# Add modeling imports here from ..albert.modeling_tf_albert import ( TFAlbertForMaskedLM, TFAlbertForMultipleChoice, @@ -175,8 +177,6 @@ from .configuration_auto import ( ) -# Add modeling imports here - logger = logging.get_logger(__name__) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py new file mode 100644 index 00000000000..b78052af1bb --- /dev/null +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py @@ -0,0 +1,43 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +{%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %} +from ...file_utils import is_tf_available, is_torch_available +{%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %} +from ...file_utils import is_torch_available +{%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %} +from ...file_utils import is_tf_available +{% endif %} +from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config +from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer + +{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %} +if is_torch_available(): + from .modeling_{{cookiecutter.lowercase_modelname}} import ( + {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, + {{cookiecutter.camelcase_modelname}}ForMaskedLM, + {{cookiecutter.camelcase_modelname}}ForMultipleChoice, + {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, + {{cookiecutter.camelcase_modelname}}ForSequenceClassification, + {{cookiecutter.camelcase_modelname}}ForTokenClassification, + {{cookiecutter.camelcase_modelname}}Layer, + {{cookiecutter.camelcase_modelname}}Model, + {{cookiecutter.camelcase_modelname}}PreTrainedModel, + load_tf_weights_in_{{cookiecutter.lowercase_modelname}}, + ) +{% endif %} +{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %} +if is_tf_available(): + from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( + TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, + TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, + TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, + TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, + TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification, + TF{{cookiecutter.camelcase_modelname}}ForTokenClassification, + TF{{cookiecutter.camelcase_modelname}}Layer, + TF{{cookiecutter.camelcase_modelname}}Model, + TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, + ) +{% endif %} \ No newline at end of file diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py index 2e064271113..8fe8cb6b494 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py @@ -14,8 +14,8 @@ # limitations under the License. """ {{cookiecutter.modelname}} model configuration """ -from .configuration_utils import PretrainedConfig -from .utils import logging +from ...configuration_utils import PretrainedConfig +from ...utils import logging logger = logging.get_logger(__name__) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 5b20528f72c..b4eaacb2da6 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -17,15 +17,14 @@ import tensorflow as tf -from .activations_tf import get_tf_activation -from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config -from .file_utils import ( +from ...activations_tf import get_tf_activation +from ...file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, ) -from .modeling_tf_outputs import ( +from ...modeling_tf_outputs import ( TFBaseModelOutput, TFBaseModelOutputWithPooling, TFMaskedLMOutput, @@ -34,7 +33,7 @@ from .modeling_tf_outputs import ( TFSequenceClassifierOutput, TFTokenClassifierOutput, ) -from .modeling_tf_utils import ( +from ...modeling_tf_utils import ( TFMaskedLanguageModelingLoss, TFMultipleChoiceLoss, TFPreTrainedModel, @@ -46,8 +45,9 @@ from .modeling_tf_utils import ( keras_serializable, shape_list, ) -from .tokenization_utils import BatchEncoding -from .utils import logging +from ...tokenization_utils import BatchEncoding +from ...utils import logging +from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config logger = logging.get_logger(__name__) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index fb8593e61be..6036f8bc4e6 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -25,13 +25,12 @@ import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss, MSELoss -from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config -from .file_utils import ( +from ...file_utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, ) -from .modeling_outputs import ( +from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, @@ -40,15 +39,16 @@ from .modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from .modeling_utils import ( +from ...modeling_utils import ( PreTrainedModel, SequenceSummary, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer, ) -from .utils import logging -from .activations import ACT2FN +from ...utils import logging +from ...activations import ACT2FN +from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config logger = logging.get_logger(__name__) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py index 16ee5916980..943fcd39a1c 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py @@ -14,7 +14,7 @@ # To replace in: "src/transformers/__init__.py" # Below: "if is_torch_available():" if generating PyTorch # Replace with: - from .modeling_{{cookiecutter.lowercase_modelname}} import ( + from .models.{{cookiecutter.lowercase_modelname}} import ( {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.camelcase_modelname}}ForMaskedLM, {{cookiecutter.camelcase_modelname}}ForMultipleChoice, @@ -30,7 +30,7 @@ # Below: "if is_tf_available():" if generating TensorFlow # Replace with: - from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( + from .models.{{cookiecutter.lowercase_modelname}} import ( TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, @@ -44,14 +44,14 @@ # End. -# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig" +# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig" # Replace with: -from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config +from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config # End. -# To replace in: "src/transformers/configuration_auto.py" +# To replace in: "src/transformers/models/auto/configuration_auto.py" # Below: "# Add configs here" # Replace with: ("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config), @@ -62,9 +62,9 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, # End. -# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig", +# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig", # Replace with: -from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config +from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config # End. # Below: "# Add full (and cased) model names here" @@ -83,7 +83,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u # Below: "# Add modeling imports here" # Replace with: -from .modeling_{{cookiecutter.lowercase_modelname}} import ( +from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import ( {{cookiecutter.camelcase_modelname}}ForMaskedLM, {{cookiecutter.camelcase_modelname}}ForMultipleChoice, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, @@ -138,7 +138,7 @@ from .modeling_{{cookiecutter.lowercase_modelname}} import ( # Below: "# Add modeling imports here" # Replace with: -from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( +from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import ( TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py index 29a762961e8..8dcbb1e9b39 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py @@ -15,8 +15,9 @@ """Tokenization classes for {{cookiecutter.modelname}}.""" {%- if cookiecutter.tokenizer_type == "Based on BERT" %} -from .tokenization_bert import BertTokenizer, BertTokenizerFast -from .utils import logging +from ...utils import logging +from ..bert.tokenization_bert import BertTokenizer +from ..bert.tokenization_bert_fast import BertTokenizerFast logger = logging.get_logger(__name__) @@ -73,14 +74,14 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast): pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION {%- elif cookiecutter.tokenizer_type == "Standalone" %} import warnings +from typing import List, Optional from tokenizers import ByteLevelBPETokenizer -from .tokenization_utils import AddedToken, PreTrainedTokenizer -from .tokenization_utils_base import BatchEncoding -from .tokenization_utils_fast import PreTrainedTokenizerFast -from typing import List, Optional -from .utils import logging +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils_base import BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging logger = logging.get_logger(__name__)