Fix KerasMetricCallback: pass generate_kwargs even if use_xla_generation is False (#24333)

* Fix `KerasMetricCallback`: always pass `generate_kwargs`.

* Reformat code using Black.
This commit is contained in:
Matěj Kripner 2023-06-19 13:51:25 +02:00 committed by GitHub
parent 0b259a3b7e
commit 3b5a56e595
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -224,7 +224,9 @@ class KerasMetricCallback(Callback):
if self.use_xla_generation:
predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
else:
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
predictions = self.model.generate(
generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
)
else:
predictions = self.model.predict_on_batch(batch)
if isinstance(predictions, dict):