diff --git a/docs/source/en/custom_models.md b/docs/source/en/custom_models.md index 4abc3ce5773..22ba58b9d9d 100644 --- a/docs/source/en/custom_models.md +++ b/docs/source/en/custom_models.md @@ -14,7 +14,7 @@ rendered properly in your Markdown viewer. --> -# Sharing custom models +# Building custom models The 🤗 Transformers library is designed to be easily extensible. Every model is fully coded in a given subfolder of the repository with no abstraction, so you can easily copy a modeling file and tweak it to your needs. @@ -22,7 +22,8 @@ of the repository with no abstraction, so you can easily copy a modeling file an If you are writing a brand new model, it might be easier to start from scratch. In this tutorial, we will show you how to write a custom model and its configuration so it can be used inside Transformers, and how you can share it with the community (with the code it relies on) so that anyone can use it, even if it's not present in the 🤗 -Transformers library. +Transformers library. We'll see how to build upon transformers and extend the framework with your hooks and +custom code. We will illustrate all of this on a ResNet model, by wrapping the ResNet class of the [timm library](https://github.com/rwightman/pytorch-image-models) into a [`PreTrainedModel`]. @@ -218,6 +219,27 @@ resnet50d.model.load_state_dict(pretrained_model.state_dict()) Now let's see how to make sure that when we do [`~PreTrainedModel.save_pretrained`] or [`~PreTrainedModel.push_to_hub`], the code of the model is saved. +## Registering a model with custom code to the auto classes + +If you are writing a library that extends 🤗 Transformers, you may want to extend the auto classes to include your own +model. This is different from pushing the code to the Hub in the sense that users will need to import your library to +get the custom models (contrarily to automatically downloading the model code from the Hub). + +As long as your config has a `model_type` attribute that is different from existing model types, and that your model +classes have the right `config_class` attributes, you can just add them to the auto classes like this: + +```py +from transformers import AutoConfig, AutoModel, AutoModelForImageClassification + +AutoConfig.register("resnet", ResnetConfig) +AutoModel.register(ResnetConfig, ResnetModel) +AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification) +``` + +Note that the first argument used when registering your custom config to [`AutoConfig`] needs to match the `model_type` +of your custom config, and the first argument used when registering your custom models to any auto model class needs +to match the `config_class` of those models. + ## Sending the code to the Hub @@ -350,23 +372,3 @@ model = AutoModelForImageClassification.from_pretrained( Note that when browsing the commit history of the model repo on the Hub, there is a button to easily copy the commit hash of any commit. -## Registering a model with custom code to the auto classes - -If you are writing a library that extends 🤗 Transformers, you may want to extend the auto classes to include your own -model. This is different from pushing the code to the Hub in the sense that users will need to import your library to -get the custom models (contrarily to automatically downloading the model code from the Hub). - -As long as your config has a `model_type` attribute that is different from existing model types, and that your model -classes have the right `config_class` attributes, you can just add them to the auto classes like this: - -```py -from transformers import AutoConfig, AutoModel, AutoModelForImageClassification - -AutoConfig.register("resnet", ResnetConfig) -AutoModel.register(ResnetConfig, ResnetModel) -AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification) -``` - -Note that the first argument used when registering your custom config to [`AutoConfig`] needs to match the `model_type` -of your custom config, and the first argument used when registering your custom models to any auto model class needs -to match the `config_class` of those models.