Reduce console spam when using the KerasMetricCallback (#18202)

* Reduce console spam when using the KerasMetricCallback

* Switch to predict_on_batch to improve performance
This commit is contained in:
Matt 2022-07-19 12:00:35 -04:00 committed by GitHub
parent ec6cd7633f
commit 8a61fe0234
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -202,7 +202,7 @@ class KerasMetricCallback(Callback):
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
else:
predictions = self.model.predict(batch)
predictions = self.model.predict_on_batch(batch)
if isinstance(predictions, dict):
# This converts any dict-subclass to a regular dict
# Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class