diff --git a/examples/tensorflow/text-classification/run_text_classification.py b/examples/tensorflow/text-classification/run_text_classification.py index 861a9ccc3c0..3c9e2600970 100644 --- a/examples/tensorflow/text-classification/run_text_classification.py +++ b/examples/tensorflow/text-classification/run_text_classification.py @@ -492,7 +492,10 @@ def main(): callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)] model.fit( - training_dataset, validation_data=eval_dataset, epochs=training_args.num_train_epochs, callbacks=callbacks + training_dataset, + validation_data=eval_dataset, + epochs=int(training_args.num_train_epochs), + callbacks=callbacks, ) elif "validation" in datasets: # If there's a validation dataset but no training set, just evaluate the metrics