
Models

The base classes [PreTrainedModel], [TFPreTrainedModel], and [FlaxPreTrainedModel] implement the common methods for loading and saving a model, either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from the Hugging Face Hub).
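As a minimal sketch of the save/load round trip, using a tiny randomly initialized BERT (arbitrary sizes) so the snippet runs offline; in practice you would pass a Hub model id to from_pretrained instead of a local path:

```python
import tempfile

from transformers import BertConfig, BertModel

# Tiny, randomly initialized model (arbitrary sizes) so this runs offline;
# real use would load a trained checkpoint by its Hub model id.
config = BertConfig(hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config)

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)                 # writes the config and weights
    reloaded = BertModel.from_pretrained(tmp)  # restores both from disk
```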

[PreTrainedModel] and [TFPreTrainedModel] also implement a few methods that are common to all models:

  • resize the input token embeddings when new tokens are added to the vocabulary
  • prune the attention heads of the model
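Both operations are ordinary methods on the model. A short sketch with a tiny randomly initialized PyTorch model (all sizes here are arbitrary):

```python
from transformers import BertConfig, BertModel

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
model = BertModel(config)

# Grow the input embedding matrix, e.g. after adding tokens to a tokenizer.
model.resize_token_embeddings(110)

# Prune attention heads 0 and 1 of layer 0; the argument maps
# layer index -> list of head indices to remove.
model.prune_heads({0: [0, 1]})
```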

The other methods common to each model are defined in [~modeling_utils.ModuleUtilsMixin] (for the PyTorch models) and [~modeling_tf_utils.TFModelUtilsMixin] (for the TensorFlow models). For text generation, see [~generation.GenerationMixin] (for the PyTorch models), [~generation.TFGenerationMixin] (for the TensorFlow models), and [~generation.FlaxGenerationMixin] (for the Flax/JAX models).
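For example, [~generation.GenerationMixin] provides generate(), which every PyTorch model with a language-modeling head inherits. A minimal sketch with a tiny randomly initialized GPT-2 (arbitrary sizes; the weights are untrained, so the generated ids are meaningless):

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=100, n_positions=64,
                    n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config)

input_ids = torch.tensor([[1, 2, 3]])
# generate() is inherited from GenerationMixin; greedy decoding here.
out = model.generate(input_ids, max_new_tokens=5, do_sample=False)
# out holds the 3 prompt tokens followed by 5 newly generated tokens.
```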

PreTrainedModel

autodoc PreTrainedModel - push_to_hub - all

ModuleUtilsMixin

autodoc modeling_utils.ModuleUtilsMixin

TFPreTrainedModel

autodoc TFPreTrainedModel - push_to_hub - all

TFModelUtilsMixin

autodoc modeling_tf_utils.TFModelUtilsMixin

FlaxPreTrainedModel

autodoc FlaxPreTrainedModel - push_to_hub - all

Pushing to the Hub

autodoc utils.PushToHubMixin
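A sketch of typical usage; the upload call is commented out because it requires Hub credentials, and the repo name shown is hypothetical:

```python
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=32, num_hidden_layers=2,
                             num_attention_heads=2, intermediate_size=64))

# push_to_hub comes from PushToHubMixin: it creates (or reuses) a repo and
# uploads the model weights and config. Requires being logged in, e.g. via
# `huggingface-cli login`. The repo name below is hypothetical.
# model.push_to_hub("my-username/my-tiny-bert", private=True)
```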

Sharded checkpoints

autodoc modeling_utils.load_sharded_checkpoint
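A sketch of saving and reloading a sharded checkpoint (tiny model, arbitrary sizes; a deliberately small max_shard_size forces the checkpoint to be split into several files plus an index):

```python
import os
import tempfile

from transformers import BertConfig, BertModel
from transformers.modeling_utils import load_sharded_checkpoint

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config)

with tempfile.TemporaryDirectory() as tmp:
    # A small max_shard_size splits the checkpoint into several shard files
    # plus an index mapping each weight name to its shard.
    model.save_pretrained(tmp, max_shard_size="50KB")
    num_shards = len([f for f in os.listdir(tmp) if "-of-" in f])

    # load_sharded_checkpoint reads the index and loads shards one by one.
    result = load_sharded_checkpoint(model, tmp)
```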