mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Document Trainer limitation on custom models (#10635)
This commit is contained in:
parent
49c61a4ae7
commit
26a33cfd8c
@ -21,7 +21,7 @@ Before instantiating your :class:`~transformers.Trainer`/:class:`~transformers.T
|
||||
customization during training.
|
||||
|
||||
The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
|
||||
<https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
|
||||
<https://github.com/NVIDIA/apex>`__ and Native AMP for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
|
||||
|
||||
Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop which supports
|
||||
the above features. To inject custom behavior you can subclass them and override the following methods:
|
||||
@ -39,6 +39,18 @@ the above features. To inject custom behavior you can subclass them and override
|
||||
- **evaluate** -- Runs an evaluation loop and returns metrics.
|
||||
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
|
||||
|
||||
.. warning::
|
||||
|
||||
The :class:`~transformers.Trainer` class is optimized for 🤗 Transformers models and can have surprising behaviors
|
||||
when you use it on other models. When using it on your own model, make sure:
|
||||
|
||||
- your model always return tuples or subclasses of :class:`~transformers.file_utils.ModelOutput`.
|
||||
- your model can compute the loss if a :obj:`labels` argument is provided and that loss is returned as the first
|
||||
element of the tuple (if your model returns tuples)
|
||||
- your model can accept multiple label arguments (use the :obj:`label_names` in your
|
||||
:class:`~transformers.TrainingArguments` to indicate their name to the :class:`~transformers.Trainer`) but none
|
||||
of them should be named :obj:`"label"`.
|
||||
|
||||
Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function for multi-label
|
||||
classification:
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user