mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Expand tutorial for custom models (#15587)
* Expand tutorial for custom models * Style * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
parent
a86ee2261e
commit
c722753afd
@ -15,68 +15,219 @@ specific language governing permissions and limitations under the License.
|
||||
The 🤗 Transformers library is designed to be easily extensible. Every model is fully coded in a given subfolder
|
||||
of the repository with no abstraction, so you can easily copy a modeling file and tweak it to your needs.
|
||||
|
||||
Once you are happy with those tweaks and trained a model you want to share with the community, there are simple steps
|
||||
to push on the Model Hub not only the weights of your model, but also the code it relies on, so that anyone in the
|
||||
community can use it, even if it's not present in the 🤗 Transformers library.
|
||||
If you are writing a brand new model, it might be easier to start from scratch. In this tutorial, we will show you
|
||||
how to write a custom model and its configuration so it can be used inside Transformers, and how you can share it
|
||||
with the community (with the code it relies on) so that anyone can use it, even if it's not present in the 🤗
|
||||
Transformers library.
|
||||
|
||||
This also applies to configurations and tokenizers (support for feature extractors and processors is coming soon).
|
||||
We will illustrate all of this on a ResNet model, by wrapping the ResNet class of the
|
||||
[timm library](https://github.com/rwightman/pytorch-image-models/tree/master/timm) into a [`PreTrainedModel`].
|
||||
|
||||
## Writing a custom configuration
|
||||
|
||||
Before we dive into the model, let's first write its configuration. The configuration of a model is an object that
|
||||
will contain all the necessary information to build the model. As we will see in the next section, the model can only
|
||||
take a `config` to be initialized, so we really need that object to be as complete as possible.
|
||||
|
||||
In our example, we will take a couple of arguments of the ResNet class that we might want to tweak. Different
|
||||
configurations will then give us the different types of ResNets that are possible. We then just store those arguments,
|
||||
after checking the validity of a few of them.
|
||||
|
||||
```python
|
||||
from transformers import PretrainedConfig
|
||||
from typing import List
|
||||
|
||||
|
||||
class ResnetConfig(PretrainedConfig):
|
||||
model_type = "resnet"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_type="bottleneck",
|
||||
layers: List[int] = [3, 4, 6, 3],
|
||||
num_classes: int = 1000,
|
||||
input_channels: int = 3,
|
||||
cardinality: int = 1,
|
||||
base_width: int = 64,
|
||||
stem_width: int = 64,
|
||||
stem_type: str = "",
|
||||
avg_down: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if block_type not in ["basic", "bottleneck"]:
|
||||
raise ValueError(f"`block` must be 'basic' or bottleneck', got {block}.")
|
||||
if stem_type not in ["", "deep", "deep-tiered"]:
|
||||
raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {block}.")
|
||||
|
||||
self.block_type = block_type
|
||||
self.layers = layers
|
||||
self.num_classes = num_classes
|
||||
self.input_channels = input_channels
|
||||
self.cardinality = cardinality
|
||||
self.base_width = base_width
|
||||
self.stem_width = stem_width
|
||||
self.stem_type = stem_type
|
||||
self.avg_down = avg_down
|
||||
super().__init__(**kwargs)
|
||||
```
|
||||
|
||||
The three important things to remember when writing you own configuration are the following:
|
||||
- you have to inherit from `PretrainedConfig`,
|
||||
- the `__init__` of your `PretrainedConfig` must accept any kwargs,
|
||||
- those `kwargs` need to be passed to the superclass `__init__`.
|
||||
|
||||
The inheritance is to make sure you get all the functionality from the 🤗 Transformers library, while the two other
|
||||
constraints come from the fact a `PretrainedConfig` has more fields than the ones you are setting. When reloading a
|
||||
config with the `from_pretrained` method, those fields need to be accepted by your config and then sent to the
|
||||
superclass.
|
||||
|
||||
Defining a `model_type` for your configuration (here `model_type="resnet"`) is not mandatory, unless you want to
|
||||
register your model with the auto classes (see last section).
|
||||
|
||||
With this done, you can easily create and save your configuration like you would do with any other model config of the
|
||||
library. Here is how we can create a resnet50d config and save it:
|
||||
|
||||
```py
|
||||
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
|
||||
resnet50d_config.save_pretrained("custom-resnet")
|
||||
```
|
||||
|
||||
This will save a file named `config.json` inside the folder `custom-resnet`. You can then reload your config with the
|
||||
`from_pretrained` method:
|
||||
|
||||
```py
|
||||
resnet50d_config = ResnetConfig.from_pretrained("custom-resnet")
|
||||
```
|
||||
|
||||
You can also use any other method of the [`PretrainedConfig`] class, like [`~PretrainedConfig.push_to_hub`] to
|
||||
directly upload your config to the Hub.
|
||||
|
||||
## Writing a custom model
|
||||
|
||||
Now that we have our ResNet configuration, we can go on writing the model. We will actually write two: one that
|
||||
extracts the hidden features from a batch of images (like [`BertModel`]) and one that is suitable for image
|
||||
classification (like [`BertModelForSequenceClassification`]).
|
||||
|
||||
As we mentioned before, we'll only write a loose wrapper of the model to keep it simple for this example. The only
|
||||
thing we need to do before writing this class is a map between the block types and actual block classes. Then the
|
||||
model is defined from the configuration by passing everything to the `ResNet` class:
|
||||
|
||||
```py
|
||||
from transformers import PreTrainedModel
|
||||
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
|
||||
from .configuration_resnet import ResnetConfig
|
||||
|
||||
|
||||
BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
|
||||
|
||||
|
||||
class ResnetModel(PreTrainedModel):
|
||||
config_class = ResnetConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
block_layer = BLOCK_MAPPING[config.block_type]
|
||||
self.model = ResNet(
|
||||
block_layer,
|
||||
config.layers,
|
||||
num_classes=config.num_classes,
|
||||
in_chans=config.input_channels,
|
||||
cardinality=config.cardinality,
|
||||
base_width=config.base_width,
|
||||
stem_width=config.stem_width,
|
||||
stem_type=config.stem_type,
|
||||
avg_down=config.avg_down,
|
||||
)
|
||||
|
||||
def forward(self, tensor):
|
||||
return self.model.forward_features(tensor)
|
||||
```
|
||||
|
||||
For the model that will classify images, we just change the forward method:
|
||||
|
||||
```py
|
||||
class ResnetModelForImageClassification(PreTrainedModel):
|
||||
config_class = ResnetConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
block_layer = BLOCK_MAPPING[config.block_type]
|
||||
self.model = ResNet(
|
||||
block_layer,
|
||||
config.layers,
|
||||
num_classes=config.num_classes,
|
||||
in_chans=config.input_channels,
|
||||
cardinality=config.cardinality,
|
||||
base_width=config.base_width,
|
||||
stem_width=config.stem_width,
|
||||
stem_type=config.stem_type,
|
||||
avg_down=config.avg_down,
|
||||
)
|
||||
|
||||
def forward(self, tensor, labels=None):
|
||||
logits = self.model(tensor)
|
||||
if labels is not None:
|
||||
loss = torch.nn.cross_entropy(logits, labels)
|
||||
return {"loss": loss, "logits": logits}
|
||||
return {"logits": logits}
|
||||
```
|
||||
|
||||
In both cases, notice how we inherit from `PreTrainedModel` and call the superclass initialization with the `config`
|
||||
(a bit like when you write a regular `torch.nn.Module`). The line that sets the `config_class` is not mandatory, unless
|
||||
you want to register your model with the auto classes (see last section).
|
||||
|
||||
<Tip>
|
||||
|
||||
If your model is very similar to a model inside the library, you can re-use the same configuration as this model.
|
||||
|
||||
</Tip>
|
||||
|
||||
You can have your model return anything you want, but returning a dictionary like we did for
|
||||
`ResnetModelForImageClassification`, with the loss included when labels are passed, will make your model directly
|
||||
usable inside the [`Trainer`] class. Using another output format is fine as long as you are planning on using your own
|
||||
training loop or another library for training.
|
||||
|
||||
Now that we have our model class, let's create one:
|
||||
|
||||
```py
|
||||
resnet50d = ResnetModelForImageClassification(resnet50d_config)
|
||||
```
|
||||
|
||||
Again, you can use any of the methods of [`PreTrainedModel`], like [`~PreTrainedModel.save_pretrained`] or
|
||||
[`~PreTrainedModel.push_to_hub`]. We will use the second in the next section, and see how to push the model weights
|
||||
with the code of our model. But first, let's load some pretrained weights inside our model.
|
||||
|
||||
In your own use case, you will probably be training your custom model on your own data. To go fast for this tutorial,
|
||||
we will use the pretrained version of the resnet50d. Since our model is just a wrapper around it, it's going to be
|
||||
easy to transfer those weights:
|
||||
|
||||
```py
|
||||
import timm
|
||||
|
||||
pretrained_model = timm.create_model("resnet50d", pretrained=True)
|
||||
resnet50d.model.load_state_dict(pretrained_model.state_dict())
|
||||
```
|
||||
|
||||
Now let's see how to make sure that when we do [`~PreTrainedModel.save_pretrained`] or [`~PreTrainedModel.push_to_hub`], the
|
||||
code of the model is saved.
|
||||
|
||||
## Sending the code to the Hub
|
||||
|
||||
First, make sure your model is fully defined in a `.py` file. It can rely on relative imports to some other files as
|
||||
long as all the files are in the same directory (we don't support submodules for this feature yet). For instance,
|
||||
let's say you have a `modeling.py` file and a `configuration.py` file in a folder of the current working directory
|
||||
named `awesome_model`, and that the modeling file defines an `AwesomeModel`, the configuration file a `AwesomeConfig`.
|
||||
long as all the files are in the same directory (we don't support submodules for this feature yet). For our example,
|
||||
we'll define a `modeling_resnet.py` file and a `configuration_resnet.py` file in a folder of the current working
|
||||
directory named `resnet_model`. The configuration file contains the code for `ResnetConfig` and the modeling file
|
||||
contains the code of `ResnetModel` and `ResnetModelForImageClassification`.
|
||||
|
||||
```
|
||||
.
|
||||
└── awesome_model
|
||||
└── resnet_model
|
||||
├── __init__.py
|
||||
├── configuration.py
|
||||
└── modeling.py
|
||||
├── configuration_resnet.py
|
||||
└── modeling_resnet.py
|
||||
```
|
||||
|
||||
The `__init__.py` can be empty, it's just there so that Python detects `awesome_model` can be use as a module.
|
||||
Here is an example of what the configuration file could look like:
|
||||
|
||||
```py
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class AwesomeConfig(PretrainedConfig):
|
||||
model_type = "awesome"
|
||||
|
||||
def __init__(self, attribute=1, hidden_size=42, **kwargs):
|
||||
self.attribute = attribute
|
||||
self.hidden_size = hidden_size
|
||||
super().__init__(**kwargs)
|
||||
```
|
||||
|
||||
and the modeling file could have content like this:
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .configuration import AwesomeConfig
|
||||
|
||||
|
||||
class AwesomeModel(PreTrainedModel):
|
||||
config_class = AwesomeConfig
|
||||
base_model_prefix = "base"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
```
|
||||
|
||||
`AwesomeModel` should subclass [`PreTrainedModel`] and `AwesomeConfig` should subclass [`PretrainedConfig`]. The
|
||||
easiest way to achieve this is to copy the modeling and configuration files of the model closest to the one you're
|
||||
coding, and then tweaking them.
|
||||
The `__init__.py` can be empty, it's just there so that Python detects `resnet_model` can be use as a module.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
@ -87,51 +238,44 @@ to import from the `transformers` package.
|
||||
|
||||
Note that you can re-use (or subclass) an existing configuration/model.
|
||||
|
||||
To share your model with the community, follow those steps: first import the custom objects.
|
||||
To share your model with the community, follow those steps: first import the ResNet model and config from the newly
|
||||
created files:
|
||||
|
||||
```py
|
||||
from awesome_model.configuration import AwesomeConfig
|
||||
from awesome_model.modeling import AwesomeModel
|
||||
from resnet_model.configuration_resnet import ResnetConfig
|
||||
from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
|
||||
```
|
||||
|
||||
Then you have to tell the library you want to copy the code files of those objects when using the `save_pretrained`
|
||||
method and properly register them with a given Auto class (especially for models), just run:
|
||||
|
||||
```py
|
||||
AwesomeConfig.register_for_auto_class()
|
||||
AwesomeModel.register_for_auto_class("AutoModel")
|
||||
ResnetConfig.register_for_auto_class()
|
||||
ResnetModel.register_for_auto_class("AutoModel")
|
||||
ResnetModelForImageClassification.register_for_auto_class("AutoModelForImageClassification")
|
||||
```
|
||||
|
||||
Note that there is no need to specify an auto class for the configuration (there is only one auto class for them,
|
||||
[`AutoConfig`]) but it's different for models. Your custom model could be suitable for sequence classification (in
|
||||
which case you should do `AwesomeModel.register_for_auto_class("AutoModelForSequenceClassification")`) or any other
|
||||
task, so you have to specify which one of the auto classes is the correct one for your model.
|
||||
[`AutoConfig`]) but it's different for models. Your custom model could be suitable for many different tasks, so you
|
||||
have to specify which one of the auto classes is the correct one for your model.
|
||||
|
||||
Next, just create the config and models as you would any other Transformer models:
|
||||
Next, let's create the config and models as we did before:
|
||||
|
||||
```py
|
||||
config = AwesomeConfig()
|
||||
model = AwesomeModel(config)
|
||||
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
|
||||
resnet50d = ResnetModelForImageClassification(resnet50d_config)
|
||||
|
||||
pretrained_model = timm.create_model("resnet50d", pretrained=True)
|
||||
resnet50d.model.load_state_dict(pretrained_model.state_dict())
|
||||
```
|
||||
|
||||
then train your model. Alternatively, you could load a pretrained checkpoint you have already trained in your model.
|
||||
|
||||
Once everything is ready, you just have to do:
|
||||
|
||||
```py
|
||||
model.save_pretrained("save_dir")
|
||||
```
|
||||
|
||||
which will not only save the model weights and the configuration in json format, but also copy the modeling and
|
||||
configuration `.py` files in this folder, so you can directly upload the result to the Hub.
|
||||
|
||||
If you have already logged in to Hugging face with
|
||||
Now to send the model to the Hub, make sure you are logged in. Either run in your terminal:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
or in a notebook with
|
||||
or from a notebook:
|
||||
|
||||
```py
|
||||
from huggingface_hub import notebook_login
|
||||
@ -139,11 +283,15 @@ from huggingface_hub import notebook_login
|
||||
notebook_login()
|
||||
```
|
||||
|
||||
you can push your model and its code to the Hub with the following:
|
||||
You can then push to to your own namespace (or an organization you are a member of) like this:
|
||||
|
||||
```py
|
||||
model.push_to_hub("model-identifier")
|
||||
```
|
||||
resnet50d.push_to_hub("custom-resnet50d")
|
||||
```
|
||||
|
||||
On top of the modeling weights and the configuration in json format, this also copied the modeling and
|
||||
configuration `.py` files in the folder `custom-resnet50d` and uploaded the result to the Hub. You can check the result
|
||||
in this [model repo](https://huggingface.co/sgugger/custom-resnet50d).
|
||||
|
||||
See the [sharing tutorial](model_sharing) for more information on the push to Hub method.
|
||||
|
||||
@ -154,18 +302,41 @@ the `from_pretrained` method. The only thing is that you have to add an extra ar
|
||||
online code and trust the author of that model, to avoid executing malicious code on your machine:
|
||||
|
||||
```py
|
||||
from transformers import AutoModel
|
||||
from transformers import AutoModelForImageClassification
|
||||
|
||||
model = AutoModel.from_pretrained("model-checkpoint", trust_remote_code=True)
|
||||
model = AutoModelForImageClassification.from_pretrained("sgugger/custom-resnet50d", trust_remote_code=True)
|
||||
```
|
||||
|
||||
It is also strongly encouraged to pass a commit hash as a `revision` to make sure the author of the models did not
|
||||
update the code with some malicious new lines (unless you fully trust the authors of the models).
|
||||
|
||||
```py
|
||||
commit_hash = "b731e5fae6d80a4a775461251c4388886fb7a249"
|
||||
model = AutoModel.from_pretrained("model-checkpoint", trust_remote_code=True, revision=commit_hash)
|
||||
commit_hash = "ed94a7c6247d8aedce4647f00f20de6875b5b292"
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
"sgugger/custom-resnet50d", trust_remote_code=True, revision=commit_hash
|
||||
)
|
||||
```
|
||||
|
||||
Note that when browsing the commit history of the model repo on the Hub, there is a button to easily copy the commit
|
||||
hash of any commit.
|
||||
|
||||
## Registering a model with custom code to the auto classes
|
||||
|
||||
If you are writing a library that extends 🤗 Transformers, you may want to extend the auto classes to include your own
|
||||
model. This is different from pushing the code to the Hub in the sense that users will need to import your library to
|
||||
get the custom models (contrarily to automatically downloading the model code from the Hub).
|
||||
|
||||
As long as your config has a `model_type` attribute that is different from existing model types, and that your model
|
||||
classes have the right `config_class` attributes, you can just add them to the auto classes likes this:
|
||||
|
||||
```py
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
|
||||
|
||||
AutoConfig.register("resnet", ResnetConfig)
|
||||
AutoModel.register(ResnetConfig, ResnetModel)
|
||||
AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification)
|
||||
```
|
||||
|
||||
Note that the first argument used when registering your custom config to [`AutoConfig`] needs to match the `model_type`
|
||||
of your custom config, and the first argument used when registering your custom models to any auto model class needs
|
||||
to match the `config_class` of those models.
|
||||
|
Loading…
Reference in New Issue
Block a user