[CI] green llama tests (#37244)

* green llama tests

* use cleanup instead

* better test comment; cleanup upgrade

* better test comment; cleanup upgrade
Joao Gante 2025-04-03 14:15:53 +01:00 committed by GitHub
parent 782d7d945d
commit 9a1c1fe7ed
15 changed files with 62 additions and 36 deletions

View File

@@ -118,7 +118,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
         with torch.no_grad():
             past_key_values = StaticCache(
                 model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 device=device,
                 dtype=torch.float16,
                 max_cache_len=seq_length + num_tokens_to_generate,
@@ -144,7 +144,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
             past_key_values = StaticCache(
                 model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 device=device,
                 dtype=torch.float16,
                 max_cache_len=seq_length + num_tokens_to_generate,
@@ -187,7 +187,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
             # TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
             past_key_values = StaticCache(
                 model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 device=device,
                 dtype=torch.float16,
                 max_cache_len=seq_length + num_tokens_to_generate + 10,
@@ -254,7 +254,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
             past_key_values = StaticCache(
                 model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 device=device,
                 dtype=torch.float16,
                 max_cache_len=seq_length + 128,
@@ -271,7 +271,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
             past_key_values = StaticCache(
                 model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 device=device,
                 dtype=torch.float16,
                 max_cache_len=seq_length + 128,
@@ -287,7 +287,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
             past_key_values = StaticCache(
                 model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 device=device,
                 dtype=torch.float16,
                 max_cache_len=seq_length + 128,
@@ -303,7 +303,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
             past_key_values = StaticCache(
                 model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 device=device,
                 dtype=torch.float16,
                 max_cache_len=seq_length + 128,
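
Most of the hunks in this commit are the same mechanical change: StaticCache's `batch_size` argument becomes `max_batch_size`. Below is a minimal sketch of the updated construction, assuming a small causal LM and a device picked at runtime; the checkpoint name, prompt, and generation length are illustrative and not part of the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Illustrative checkpoint; any causal LM that supports StaticCache works the same way.
checkpoint = "meta-llama/Llama-3.2-1B"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to(device)

inputs = tokenizer("Simply put, the theory of relativity states that", return_tensors="pt").to(device)
num_tokens_to_generate = 32

# `max_batch_size` replaces the `batch_size` keyword shown on the left-hand side of the hunks above.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=inputs["input_ids"].shape[0],
    max_cache_len=inputs["input_ids"].shape[1] + num_tokens_to_generate,
    device=device,
    dtype=torch.float16,
)
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=num_tokens_to_generate)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))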

View File

@@ -93,7 +93,7 @@ model.generation_config.max_new_tokens = 16
 past_key_values = StaticCache(
     config=model.config,
-    batch_size=1,
+    max_batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,
@@ -159,7 +159,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(

View File

@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
 past_key_values = StaticCache(
     config=model.config,
-    batch_size=1,
+    max_batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,

View File

@@ -3075,6 +3075,7 @@ def cleanup(device: str, gc_collect=False):
     if gc_collect:
         gc.collect()
     backend_empty_cache(device)
+    torch._dynamo.reset()
 
 
 # Type definition of key used in `Expectations` class.
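
The upgraded helper now also resets torch.compile state between tests, so graphs compiled in one test cannot pin memory for the next. A standalone sketch of the same idea follows; the real `cleanup` lives in `transformers.testing_utils` and calls the backend-agnostic `backend_empty_cache`, for which `torch.cuda.empty_cache()` stands in here.

import gc

import torch


def cleanup(device: str, gc_collect: bool = False) -> None:
    # Optionally run a full garbage-collection pass first (slower, so off by default).
    if gc_collect:
        gc.collect()
    # Release cached allocator blocks. The real helper dispatches per backend; this sketch handles CUDA only.
    if device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
    # New in this commit: drop compiled graphs and guards so torch.compile artifacts don't leak across tests.
    torch._dynamo.reset()


cleanup("cuda", gc_collect=False)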

View File

@@ -2285,6 +2285,14 @@ class GenerationTesterMixin:
             inputs_dict[input_name] = input_data
         main_input = inputs_dict[model_class.main_input_name]
 
+        # FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
+        # attention masks at test time and, with generate, the mask will be appended with 1s on the right,
+        # resulting in a mask with holes (not supported properly by FA2).
+        if attn_implementation == "flash_attention_2":
+            for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
+                if input_name in inputs_dict:
+                    inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])
+
         # make sure that all models have enough positions for generation
         if hasattr(config, "max_position_embeddings"):
             config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
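
The block added above works around a FlashAttention-2 limitation: the test masks are right-padded and `generate` appends 1s on the right, which would leave a "hole" of 0s in the middle of the mask. A toy reproduction of the normalization step, with made-up tensors for illustration:

import torch

# Made-up batch: a right-padded attention mask, as the test mixin would produce.
inputs_dict = {
    "input_ids": torch.randint(0, 100, (2, 6)),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 0, 0],
                                    [1, 1, 1, 0, 0, 0]]),
}
attn_implementation = "flash_attention_2"

# Appending 1s for newly generated tokens would yield e.g. [1, 1, 1, 0, 0, 0, 1, 1] -- a mask with a
# hole, which FA2 does not handle. The test therefore falls back to fully dense masks for FA2 runs.
if attn_implementation == "flash_attention_2":
    for name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
        if name in inputs_dict:
            inputs_dict[name] = torch.ones_like(inputs_dict[name])

print(inputs_dict["attention_mask"])  # all ones
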
@@ -2339,8 +2347,6 @@ class GenerationTesterMixin:
     @slow
     def test_eager_matches_fa2_generate(self):
         """Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
-        # TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
-        # check whether we still need the overwrites
         self._test_attention_implementation("flash_attention_2")
 
     def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
@@ -3974,7 +3980,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         # TODO: We need to raise a warning in case the cache is not set correctly
         # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
         #     past_key_values = StaticCache(
-        #         config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
+        #         config=model.config, max_batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
         #     )
         #     results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
@@ -3982,7 +3988,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         layer_device_map = {0: 0, 1: 1}
         past_key_values = StaticCache(
             config=model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=30,
             device=torch_device,
             dtype=model.dtype,
@@ -4183,7 +4189,11 @@ class GenerationIntegrationTests(unittest.TestCase):
         batch_size = 2
         query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
         static_cache = StaticCache(
-            config=config, batch_size=batch_size, max_cache_len=max_cache_len, device=torch_device, dtype=torch.float32
+            config=config,
+            max_batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            device=torch_device,
+            dtype=torch.float32,
         )
         static_cache = model(init_input_ids, past_key_values=static_cache).past_key_values
         model_inputs = model.prepare_inputs_for_generation(

View File

@@ -21,6 +21,7 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, DeepseekV3Config, is_torch_available, set_seed
 from transformers.testing_utils import (
+    cleanup,
     require_read_token,
     require_torch,
     require_torch_accelerator,
@@ -605,6 +606,10 @@ class DeepseekV3IntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_torch_accelerator
     @require_read_token

View File

@@ -25,6 +25,7 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, DiffLlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
+    cleanup,
     require_bitsandbytes,
     require_flash_attn,
     require_read_token,
@@ -685,6 +686,10 @@ class DiffLlamaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_torch_accelerator
     @require_read_token
@@ -884,7 +889,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
@@ -932,7 +937,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,

View File

@@ -23,6 +23,7 @@ from packaging import version
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
+    cleanup,
     is_flaky,
     require_bitsandbytes,
     require_flash_attn,
@@ -498,6 +499,10 @@ class GemmaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @require_read_token
     def test_model_2b_fp16(self):
         model_id = "google/gemma-2b"

View File

@@ -549,6 +549,13 @@ class LlamaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
+        # some memory allocated in the cache, which means some object is not being released properly. This causes some
+        # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
+        # Investigate the root cause.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_read_token
     def test_llama_3_1_hard(self):
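
The same tearDown pattern is mirrored, with a pointer back to this TODO, in the DeepseekV3, DiffLlama, Gemma, and T5 integration tests in this commit. A minimal sketch of how such a test class wires it up, assuming only the `transformers.testing_utils` helpers shown in this diff; the class and test names are placeholders.

import unittest

from transformers.testing_utils import cleanup, torch_device


class ExampleIntegrationTest(unittest.TestCase):
    # Illustrative class name; mirrors the tearDown added to the integration tests in this commit.

    def tearDown(self):
        # Free cached device memory (and, with the upgraded helper, reset torch._dynamo) after every test
        # so large checkpoints keep fitting on the CI GPU; gc_collect=False skips the slower full GC pass.
        cleanup(torch_device, gc_collect=False)

    def test_placeholder(self):
        self.assertTrue(True)
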
@@ -748,14 +755,6 @@ class LlamaIntegrationTest(unittest.TestCase):
                 "Simply put, the theory of relativity states that 1) the speed of light is the same for all "
                 "observers, regardless of their location, and 2) the laws of physics are the same for all observers"
             ],
-            "meta-llama/Llama-3.2-3B": [
-                "Simply put, the theory of relativity states that 1. the speed of light is constant, and 2. "
-                "the speed of light is the fastest speed possible"
-            ],
-            "meta-llama/Llama-2-7b-hf": [
-                "Simply put, the theory of relativity states that 1) the speed of light is a constant, and 2) "
-                "the laws of physics are the same for all",
-            ],
         }
 
         for llama_model_ckp, EXPECTED_TEXT_COMPLETION in llama_models.items():
@@ -946,7 +945,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
@@ -994,7 +993,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,

View File

@@ -53,7 +53,7 @@ if is_torch_available():
             self.model = model
             self.cache = StaticCache(
                 config=model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=max_seq_len,
                 device=self.model.device,
                 dtype=self.model.dtype,

View File

@@ -227,10 +227,6 @@ class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         pass
 
-    @unittest.skip(reason="This one tries to use right padding as well")
-    def test_eager_matches_fa2_generate(self):
-        pass
-
     @unittest.skip(reason="Depending on input modalities, some params may not have gradients")
     def test_training_gradient_checkpointing(self):
         pass

View File

@@ -52,7 +52,7 @@ if is_torch_available():
             self.model = model
             self.cache = StaticCache(
                 config=model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=max_seq_len,
                 device=self.model.device,
                 dtype=self.model.dtype,

View File

@@ -24,6 +24,7 @@ from transformers import T5Config, is_torch_available
 from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
+    cleanup,
     require_accelerate,
     require_sentencepiece,
     require_tokenizers,
@@ -1170,6 +1171,10 @@ class T5ModelFp16Tests(unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class T5ModelIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @cached_property
     def model(self):
         return T5ForConditionalGeneration.from_pretrained("google-t5/t5-base").to(torch_device)

View File

@@ -226,7 +226,7 @@ class AqlmTest(unittest.TestCase):
         # Setup static KV cache for generation
         past_key_values = StaticCache(
             config=self.quantized_model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=seq_length + self.max_new_tokens + 1,
             device=torch_device,
             dtype=self.quantized_model.config._pre_quantization_dtype,

View File

@@ -207,7 +207,7 @@ class SpQRTest(unittest.TestCase):
         # Setup static KV cache for generation
         past_key_values = StaticCache(
             config=self.quantized_model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=seq_length + self.max_new_tokens + 1,
             device=torch_device,
             dtype=self.quantized_model.config._pre_quantization_dtype,