mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix KerasMetricCallback
: pass generate_kwargs
even if use_xla_generation
is False (#24333)
* Fix `KerasMetricCallback`: always pass `generate_kwargs`. * Reformat code using Black.
This commit is contained in:
parent
0b259a3b7e
commit
3b5a56e595
@ -224,7 +224,9 @@ class KerasMetricCallback(Callback):
|
||||
if self.use_xla_generation:
|
||||
predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
|
||||
else:
|
||||
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
|
||||
predictions = self.model.generate(
|
||||
generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
|
||||
)
|
||||
else:
|
||||
predictions = self.model.predict_on_batch(batch)
|
||||
if isinstance(predictions, dict):
|
||||
|
Loading…
Reference in New Issue
Block a user