Fix failing tests for XLA generation in TF (#18298)

* Fix failing test_xla_generate_slow tests

* Fix failing speech-to-text xla_generate tests
This commit is contained in:
Daniel Suess 2022-08-03 15:45:15 +02:00 committed by GitHub
parent a507908cd3
commit 8fb7c908c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1685,6 +1685,17 @@ class TFModelTesterMixin:
config.do_sample = False
config.num_beams = num_beams
config.num_return_sequences = num_return_sequences
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
if hasattr(config, var_name):
try:
setattr(config, var_name, max_length)
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
pass
model = model_class(config)
if model.supports_xla_generation:
@ -1714,15 +1725,6 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
# TODO (Joao): find the issues related to the following models. They are passing the fast test, but failing
# the slow one.
if any(
[
model in str(self).lower()
for model in ["tfbart", "tfblenderbot", "tfmarian", "tfmbart", "tfopt", "tfpegasus"]
]
):
return
num_beams = 8
num_return_sequences = 2
max_length = 128