Correctly cast num_train_epochs to int (#11379)

This commit is contained in:
Matt 2021-04-22 13:49:59 +01:00 committed by GitHub
parent 881945c0b5
commit 2617396094
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -492,7 +492,10 @@ def main():
callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
model.fit(
training_dataset, validation_data=eval_dataset, epochs=training_args.num_train_epochs, callbacks=callbacks
training_dataset,
validation_data=eval_dataset,
epochs=int(training_args.num_train_epochs),
callbacks=callbacks,
)
elif "validation" in datasets:
# If there's a validation dataset but no training set, just evaluate the metrics