Add an option to reduce compile() console spam (#23938)

* Add an option to reduce compile() console spam

* Add annotations to the example scripts

* Add notes to the quicktour docs as well

* minor fix
This commit is contained in:
Matt 2023-06-02 15:28:52 +01:00 committed by GitHub
parent c9cf337772
commit 167a0d8f87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 54 additions and 31 deletions

View File

@ -532,12 +532,12 @@ All models are a standard [`tf.keras.Model`](https://www.tensorflow.org/api_docs
... ) # doctest: +SKIP
```
5. When you're ready, you can call `compile` and `fit` to start training:
5. When you're ready, you can call `compile` and `fit` to start training. Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> from tensorflow.keras.optimizers import Adam
>>> model.compile(optimizer=Adam(3e-5))
>>> model.compile(optimizer=Adam(3e-5)) # No loss argument!
>>> model.fit(tf_dataset) # doctest: +SKIP
```

View File

@ -306,12 +306,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
... )
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
This can be done by specifying where to push your model and tokenizer in the [`~transformers.PushToHubCallback`]:

View File

@ -301,12 +301,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
... )
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
This can be done by specifying where to push your model and tokenizer in the [`~transformers.PushToHubCallback`]:

View File

@ -335,10 +335,10 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
... )
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
The last two things to setup before you start training is to compute the accuracy from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).

View File

@ -377,7 +377,7 @@ Start by defining the hyperparameters, optimizer and learning rate schedule:
```
Then, load SegFormer with [`TFAutoModelForSemanticSegmentation`] along with the label mappings, and compile it with the
optimizer:
optimizer. Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> from transformers import TFAutoModelForSemanticSegmentation
@ -387,7 +387,7 @@ optimizer:
... id2label=id2label,
... label2id=label2id,
... )
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
Convert your datasets to the `tf.data.Dataset` format using the [`~datasets.Dataset.to_tf_dataset`] and the [`DefaultDataCollator`]:

View File

@ -259,12 +259,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
... )
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
The last two things to setup before you start training is to compute the accuracy from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).

View File

@ -267,12 +267,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
... )
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
The last two things to setup before you start training is to compute the ROUGE score from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).

View File

@ -361,12 +361,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
... )
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
The last two things to setup before you start training is to compute the seqeval scores from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).

View File

@ -276,12 +276,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
... )
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
>>> model.compile(optimizer=optimizer) # No loss argument!
```
The last two things to setup before you start training is to compute the SacreBLEU metric from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).

View File

@ -191,7 +191,7 @@ tokenized_data = dict(tokenized_data)
labels = np.array(dataset["label"]) # Label is already an array of 0 and 1
```
Finally, load, [`compile`](https://keras.io/api/models/model_training_apis/#compile-method), and [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) the model:
Finally, load, [`compile`](https://keras.io/api/models/model_training_apis/#compile-method), and [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) the model. Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
```py
from transformers import TFAutoModelForSequenceClassification
@ -200,7 +200,7 @@ from tensorflow.keras.optimizers import Adam
# Load and compile our model
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
# Lower learning rates are often better for fine-tuning transformers
model.compile(optimizer=Adam(3e-5))
model.compile(optimizer=Adam(3e-5)) # No loss argument!
model.fit(tokenized_data, labels)
```
@ -261,7 +261,7 @@ list of samples into a batch and apply any preprocessing you want. See our
Once you've created a `tf.data.Dataset`, you can compile and fit the model as before:
```py
model.compile(optimizer=Adam(3e-5))
model.compile(optimizer=Adam(3e-5)) # No loss argument!
model.fit(tf_dataset)
```

View File

@ -561,6 +561,8 @@ def main():
weight_decay_rate=training_args.weight_decay,
adam_global_clipnorm=training_args.max_grad_norm,
)
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla)
if not training_args.do_eval:

View File

@ -497,6 +497,8 @@ def main():
collate_fn=collate_fn,
).with_options(dataset_options)
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla, metrics=["accuracy"])
push_to_hub_model_id = training_args.push_to_hub_model_id

View File

@ -235,8 +235,10 @@ def main(args):
num_warmup_steps=total_train_steps // 20,
init_lr=args.learning_rate,
weight_decay_rate=args.weight_decay_rate,
# TODO Add the other Adam parameters?
)
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, metrics=["accuracy"])
def decode_fn(example):

View File

@ -537,7 +537,8 @@ def main():
adam_global_clipnorm=training_args.max_grad_norm,
)
# no user-specified loss = will use the model internal loss
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla)
# endregion

View File

@ -559,8 +559,9 @@ def main():
adam_global_clipnorm=training_args.max_grad_norm,
)
# no user-specified loss = will use the model internal loss
model.compile(optimizer=optimizer, jit_compile=training_args.xla, run_eagerly=True)
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla)
# endregion
# region Preparing push_to_hub and model card

View File

@ -455,6 +455,8 @@ def main():
)
else:
optimizer = None
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, metrics=["accuracy"], jit_compile=training_args.xla)
# endregion

View File

@ -656,7 +656,8 @@ def main():
adam_global_clipnorm=training_args.max_grad_norm,
)
# no user-specified loss = will use the model internal loss
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla, metrics=["accuracy"])
else:

View File

@ -674,6 +674,8 @@ def main():
# endregion
# region Training
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla)
eval_metrics = None
if training_args.do_train:

View File

@ -453,6 +453,8 @@ def main():
metrics = []
else:
metrics = ["accuracy"]
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, metrics=metrics, jit_compile=training_args.xla)
# endregion

View File

@ -487,6 +487,8 @@ def main():
metrics = []
else:
metrics = ["accuracy"]
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, metrics=metrics)
# endregion

View File

@ -454,7 +454,8 @@ def main():
weight_decay_rate=training_args.weight_decay,
adam_global_clipnorm=training_args.max_grad_norm,
)
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla)
# endregion

View File

@ -643,6 +643,8 @@ def main():
# region Training
eval_metrics = None
# Transformers models compute the right loss for their task by default when labels are passed, and will
# use this for training unless you specify your own loss function in compile().
model.compile(optimizer=optimizer, jit_compile=training_args.xla)
if training_args.do_train:

View File

@ -1498,7 +1498,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
def compile(
self,
optimizer="rmsprop",
loss="passthrough",
loss="auto_with_warning",
metrics=None,
loss_weights=None,
weighted_metrics=None,
@ -1510,13 +1510,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
function themselves.
"""
if loss == "passthrough":
logger.warning(
if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility
logger.info(
"No loss specified in compile() - the model's internal loss computation will be used as the "
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
"To disable this behaviour please pass a loss argument, or explicitly pass "
"`loss=None` if you do not want your model to compute a loss."
"`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to "
"get the internal loss without printing this info string."
)
loss = "auto"
if loss == "auto":
loss = dummy_loss
self._using_dummy_loss = True
else: