Fix tests failed with gated repos. (#37484)

* fix

* slow

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-04-14 12:08:13 +02:00 committed by GitHub
parent 1ef64710d2
commit ac1df5fccd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -34,6 +34,7 @@ from transformers.testing_utils import (
require_accelerate, require_accelerate,
require_flash_attn, require_flash_attn,
require_optimum_quanto, require_optimum_quanto,
require_read_token,
require_torch, require_torch,
require_torch_accelerator, require_torch_accelerator,
require_torch_gpu, require_torch_gpu,
@ -4283,6 +4284,8 @@ class GenerationIntegrationTests(unittest.TestCase):
gen_out = compiled_generate(**model_inputs, generation_config=generation_config) gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
@require_read_token
@slow
def test_assisted_generation_early_exit(self): def test_assisted_generation_early_exit(self):
""" """
Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache
@ -4791,6 +4794,7 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids)) self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input)) self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
@require_read_token
@slow @slow
@require_torch_gpu @require_torch_gpu
def test_cache_device_map_with_vision_layer_device_map(self): def test_cache_device_map_with_vision_layer_device_map(self):