Make _test_xla_generate less flaky (#22996)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a0e7332839
commit cf7baf4060
@@ -1868,7 +1868,18 @@ class TFModelTesterMixin:
             generated = model.generate(inputs, **generate_kwargs).numpy()
             generate_xla = tf.function(model.generate, jit_compile=True)
             generated_xla = generate_xla(inputs, **generate_kwargs).numpy()
-            self.assertListEqual(generated.tolist(), generated_xla.tolist())
+
+            # Due to numerical instability, fail the test only if more than 10% of the input sequences
+            # give different outputs between the XLA and non-XLA versions. If there are fewer than
+            # 10 examples, be strict and allow no difference.
+            diff = [[], []]
+            for _generated, _generated_xla in zip(generated.tolist(), generated_xla.tolist()):
+                if _generated != _generated_xla:
+                    diff[0].append(_generated)
+                    diff[1].append(_generated_xla)
+            ratio = len(diff[0]) / len(generated)
+            if ratio > 0.1 or (len(diff[0]) > 0 and len(generated) < 10):
+                self.assertListEqual(diff[0], diff[1])

         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()