mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Fix flaky test vision encoder-decoder generate (#28923)
This commit is contained in:
parent
0507e69d34
commit
354775bc57
@ -23,7 +23,6 @@ from packaging import version
|
|||||||
|
|
||||||
from transformers import DonutProcessor, NougatProcessor, TrOCRProcessor
|
from transformers import DonutProcessor, NougatProcessor, TrOCRProcessor
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_flaky,
|
|
||||||
require_levenshtein,
|
require_levenshtein,
|
||||||
require_nltk,
|
require_nltk,
|
||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
@ -286,6 +285,8 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.config.eos_token_id = None
|
enc_dec_model.config.eos_token_id = None
|
||||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
enc_dec_model.config.decoder.eos_token_id = None
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
|
||||||
|
enc_dec_model.generation_config.eos_token_id = None
|
||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
inputs = pixel_values
|
inputs = pixel_values
|
||||||
@ -324,10 +325,6 @@ class EncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||||
|
|
||||||
# FIXME @gante: flaky test
|
|
||||||
@is_flaky(
|
|
||||||
description="Fails on distributed runs e.g.: https://app.circleci.com/pipelines/github/huggingface/transformers/83611/workflows/666b01c9-1be8-4daa-b85d-189e670fc168/jobs/1078635/tests#failed-test-0"
|
|
||||||
)
|
|
||||||
def test_encoder_decoder_model_generate(self):
|
def test_encoder_decoder_model_generate(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||||
|
Loading…
Reference in New Issue
Block a user