Add an option to reduce compile() console spam (#23938)
* Add an option to reduce compile() console spam
* Add annotations to the example scripts
* Add notes to the quicktour docs as well
* minor fix
commit 167a0d8f87 (parent c9cf337772)
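In practical terms, the diff below changes the default `loss` sentinel in `TFPreTrainedModel.compile()` from `"passthrough"` to `"auto_with_warning"` and demotes the console message from a warning to an info log. A minimal sketch of the resulting behaviour, based on the diff itself (the checkpoint name is just a placeholder taken from the docs):

```py
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# No loss argument: the model's internal task loss is used, and an info-level
# message (previously a warning) explains this.
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))

# loss="auto" selects the same internal loss but skips the message entirely.
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5), loss="auto")

# An explicit Keras loss still overrides the internal one, as before.
model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
```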
@@ -532,12 +532,12 @@ All models are a standard [`tf.keras.Model`](https://www.tensorflow.org/api_docs
 ... ) # doctest: +SKIP
 ```
 
-5. When you're ready, you can call `compile` and `fit` to start training:
+5. When you're ready, you can call `compile` and `fit` to start training. Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> from tensorflow.keras.optimizers import Adam
 
->>> model.compile(optimizer=Adam(3e-5))
+>>> model.compile(optimizer=Adam(3e-5)) # No loss argument!
 >>> model.fit(tf_dataset) # doctest: +SKIP
 ```
 
@@ -306,12 +306,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
 ... )
 ```
 
-Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> import tensorflow as tf
 
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 This can be done by specifying where to push your model and tokenizer in the [`~transformers.PushToHubCallback`]:
@@ -301,12 +301,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
 ... )
 ```
 
-Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> import tensorflow as tf
 
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 This can be done by specifying where to push your model and tokenizer in the [`~transformers.PushToHubCallback`]:
@@ -335,10 +335,10 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
 ... )
 ```
 
-Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 The last two things to setup before you start training is to compute the accuracy from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).
@@ -377,7 +377,7 @@ Start by defining the hyperparameters, optimizer and learning rate schedule:
 ```
 
 Then, load SegFormer with [`TFAutoModelForSemanticSegmentation`] along with the label mappings, and compile it with the
-optimizer:
+optimizer. Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> from transformers import TFAutoModelForSemanticSegmentation
@@ -387,7 +387,7 @@ optimizer:
 ... id2label=id2label,
 ... label2id=label2id,
 ... )
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 Convert your datasets to the `tf.data.Dataset` format using the [`~datasets.Dataset.to_tf_dataset`] and the [`DefaultDataCollator`]:
@@ -259,12 +259,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
 ... )
 ```
 
-Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> import tensorflow as tf
 
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 The last two things to setup before you start training is to compute the accuracy from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).
@@ -267,12 +267,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
 ... )
 ```
 
-Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> import tensorflow as tf
 
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 The last two things to setup before you start training is to compute the ROUGE score from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).
@@ -361,12 +361,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
 ... )
 ```
 
-Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> import tensorflow as tf
 
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 The last two things to setup before you start training is to compute the seqeval scores from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).
@@ -276,12 +276,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPre
 ... )
 ```
 
-Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 >>> import tensorflow as tf
 
->>> model.compile(optimizer=optimizer)
+>>> model.compile(optimizer=optimizer) # No loss argument!
 ```
 
 The last two things to setup before you start training is to compute the SacreBLEU metric from the predictions, and provide a way to push your model to the Hub. Both are done by using [Keras callbacks](../main_classes/keras_callbacks).
@@ -191,7 +191,7 @@ tokenized_data = dict(tokenized_data)
 labels = np.array(dataset["label"]) # Label is already an array of 0 and 1
 ```
 
-Finally, load, [`compile`](https://keras.io/api/models/model_training_apis/#compile-method), and [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) the model:
+Finally, load, [`compile`](https://keras.io/api/models/model_training_apis/#compile-method), and [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) the model. Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:
 
 ```py
 from transformers import TFAutoModelForSequenceClassification
@@ -200,7 +200,7 @@ from tensorflow.keras.optimizers import Adam
 # Load and compile our model
 model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
 # Lower learning rates are often better for fine-tuning transformers
-model.compile(optimizer=Adam(3e-5))
+model.compile(optimizer=Adam(3e-5)) # No loss argument!
 
 model.fit(tokenized_data, labels)
 ```
@@ -261,7 +261,7 @@ list of samples into a batch and apply any preprocessing you want. See our
 Once you've created a `tf.data.Dataset`, you can compile and fit the model as before:
 
 ```py
-model.compile(optimizer=Adam(3e-5))
+model.compile(optimizer=Adam(3e-5)) # No loss argument!
 
 model.fit(tf_dataset)
 ```
@@ -561,6 +561,8 @@ def main():
 weight_decay_rate=training_args.weight_decay,
 adam_global_clipnorm=training_args.max_grad_norm,
 )
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, jit_compile=training_args.xla)
 
 if not training_args.do_eval:
@@ -497,6 +497,8 @@ def main():
 collate_fn=collate_fn,
 ).with_options(dataset_options)
 
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, jit_compile=training_args.xla, metrics=["accuracy"])
 
 push_to_hub_model_id = training_args.push_to_hub_model_id
@@ -235,8 +235,10 @@ def main(args):
 num_warmup_steps=total_train_steps // 20,
 init_lr=args.learning_rate,
 weight_decay_rate=args.weight_decay_rate,
-# TODO Add the other Adam parameters?
 )
 
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, metrics=["accuracy"])
 
 def decode_fn(example):
@@ -537,7 +537,8 @@ def main():
 adam_global_clipnorm=training_args.max_grad_norm,
 )
 
-# no user-specified loss = will use the model internal loss
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, jit_compile=training_args.xla)
 # endregion
 
@@ -559,8 +559,9 @@ def main():
 adam_global_clipnorm=training_args.max_grad_norm,
 )
 
-# no user-specified loss = will use the model internal loss
-model.compile(optimizer=optimizer, jit_compile=training_args.xla, run_eagerly=True)
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
+model.compile(optimizer=optimizer, jit_compile=training_args.xla)
 # endregion
 
 # region Preparing push_to_hub and model card
@@ -455,6 +455,8 @@ def main():
 )
 else:
 optimizer = None
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, metrics=["accuracy"], jit_compile=training_args.xla)
 # endregion
 
@@ -656,7 +656,8 @@ def main():
 adam_global_clipnorm=training_args.max_grad_norm,
 )
 
-# no user-specified loss = will use the model internal loss
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, jit_compile=training_args.xla, metrics=["accuracy"])
 
 else:
@@ -674,6 +674,8 @@ def main():
 # endregion
 
 # region Training
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, jit_compile=training_args.xla)
 eval_metrics = None
 if training_args.do_train:
@@ -453,6 +453,8 @@ def main():
 metrics = []
 else:
 metrics = ["accuracy"]
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, metrics=metrics, jit_compile=training_args.xla)
 # endregion
 
@@ -487,6 +487,8 @@ def main():
 metrics = []
 else:
 metrics = ["accuracy"]
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, metrics=metrics)
 # endregion
 
@@ -454,7 +454,8 @@ def main():
 weight_decay_rate=training_args.weight_decay,
 adam_global_clipnorm=training_args.max_grad_norm,
 )
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, jit_compile=training_args.xla)
 # endregion
 
@@ -643,6 +643,8 @@ def main():
 
 # region Training
 eval_metrics = None
+# Transformers models compute the right loss for their task by default when labels are passed, and will
+# use this for training unless you specify your own loss function in compile().
 model.compile(optimizer=optimizer, jit_compile=training_args.xla)
 
 if training_args.do_train:
@@ -1498,7 +1498,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
 def compile(
 self,
 optimizer="rmsprop",
-loss="passthrough",
+loss="auto_with_warning",
 metrics=None,
 loss_weights=None,
 weighted_metrics=None,
@@ -1510,13 +1510,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
 This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
 function themselves.
 """
-if loss == "passthrough":
-logger.warning(
+if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility
+logger.info(
 "No loss specified in compile() - the model's internal loss computation will be used as the "
 "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
 "To disable this behaviour please pass a loss argument, or explicitly pass "
-"`loss=None` if you do not want your model to compute a loss."
+"`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to "
+"get the internal loss without printing this info string."
 )
+loss = "auto"
+if loss == "auto":
 loss = dummy_loss
 self._using_dummy_loss = True
 else:
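Because the message is now emitted with `logger.info` rather than `logger.warning`, it also respects the library's logging verbosity controls. A short sketch, assuming the standard `transformers` logging utilities:

```py
from transformers.utils import logging

# Raise the verbosity threshold so info-level messages, including the
# compile() loss note, are no longer printed.
logging.set_verbosity_error()
```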