mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00
Fix failing tests for XLA generation in TF (#18298)
* Fix failing test_xla_generate_slow tests * Fix failing speech-to-text xla_generate tests
This commit is contained in:
parent
a507908cd3
commit
8fb7c908c8
@ -1685,6 +1685,17 @@ class TFModelTesterMixin:
|
|||||||
config.do_sample = False
|
config.do_sample = False
|
||||||
config.num_beams = num_beams
|
config.num_beams = num_beams
|
||||||
config.num_return_sequences = num_return_sequences
|
config.num_return_sequences = num_return_sequences
|
||||||
|
|
||||||
|
# fix config for models with additional sequence-length limiting settings
|
||||||
|
for var_name in ["max_position_embeddings", "max_target_positions"]:
|
||||||
|
if hasattr(config, var_name):
|
||||||
|
try:
|
||||||
|
setattr(config, var_name, max_length)
|
||||||
|
except NotImplementedError:
|
||||||
|
# xlnet will raise an exception when trying to set
|
||||||
|
# max_position_embeddings.
|
||||||
|
pass
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
if model.supports_xla_generation:
|
if model.supports_xla_generation:
|
||||||
@ -1714,15 +1725,6 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
||||||
"""
|
"""
|
||||||
# TODO (Joao): find the issues related to the following models. They are passing the fast test, but failing
|
|
||||||
# the slow one.
|
|
||||||
if any(
|
|
||||||
[
|
|
||||||
model in str(self).lower()
|
|
||||||
for model in ["tfbart", "tfblenderbot", "tfmarian", "tfmbart", "tfopt", "tfpegasus"]
|
|
||||||
]
|
|
||||||
):
|
|
||||||
return
|
|
||||||
num_beams = 8
|
num_beams = 8
|
||||||
num_return_sequences = 2
|
num_return_sequences = 2
|
||||||
max_length = 128
|
max_length = 128
|
||||||
|
Loading…
Reference in New Issue
Block a user