Fix prepare_tf_dataset when drop_remainder is not supplied (#17950)

This commit is contained in:
Matt 2022-06-29 19:23:39 +01:00 committed by GitHub
parent bc019b0e5f
commit 5feac3d080
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1215,6 +1215,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
output_columns = list(output_signature.keys())
feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
label_cols = [col for col in output_columns if col in model_labels]
if drop_remainder is None:
drop_remainder = shuffle
tf_dataset = dataset.to_tf_dataset(
columns=feature_cols,
label_cols=label_cols,