diff --git a/benchmark/llama.py b/benchmark/llama.py
index bc91b29b581..1857dee3d66 100644
--- a/benchmark/llama.py
+++ b/benchmark/llama.py
@@ -118,7 +118,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
     with torch.no_grad():
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate,
@@ -144,7 +144,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,

         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate,
@@ -187,7 +187,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
         # TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate + 10,
@@ -254,7 +254,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,

         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
@@ -271,7 +271,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,

         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
@@ -287,7 +287,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,

         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
@@ -303,7 +303,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,

         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,
diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 7c9bc154ab6..e8e20dab5db 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -93,7 +93,7 @@ model.generation_config.max_new_tokens = 16

 past_key_values = StaticCache(
     config=model.config,
-    batch_size=1,
+    max_batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,
@@ -159,7 +159,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(
diff --git a/docs/source/ko/llm_optims.md b/docs/source/ko/llm_optims.md
index 99eabc19ce8..f6eaa58c000 100644
--- a/docs/source/ko/llm_optims.md
+++ b/docs/source/ko/llm_optims.md
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16

 past_key_values = StaticCache(
     config=model.config,
-    batch_size=1,
+    max_batch_size=1,
     # 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,
@@ -109,7 +109,7 @@ outputs = model.generate(**input_ids, past_key_values=past_key_values)
 print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']

-# 생성된 텍스트와 동일한 캐시 객체를 전달하여, 중단한 곳에서 생성을 계속합니다. 
+# 생성된 텍스트와 동일한 캐시 객체를 전달하여, 중단한 곳에서 생성을 계속합니다.
 # 다중 턴 대화의 경우, 생성된 텍스트에 새로운 사용자 입력을 추가할 수 있습니다.
 new_input_ids = outputs
 outputs = model.generate(new_input_ids, past_key_values=past_key_values)
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 62409ff70d2..00fc9504242 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -3075,6 +3075,7 @@ def cleanup(device: str, gc_collect=False):
     if gc_collect:
         gc.collect()
     backend_empty_cache(device)
+    torch._dynamo.reset()


 # Type definition of key used in `Expectations` class.
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index f61b26c26ec..e726ecbd507 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2285,6 +2285,14 @@ class GenerationTesterMixin:
             inputs_dict[input_name] = input_data
         main_input = inputs_dict[model_class.main_input_name]

+        # FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
+        # attention masks at test time and, with generate, the mask will be appended with 1s on the right,
+        # resulting in a mask with holes (not supported properly by FA2).
+        if attn_implementation == "flash_attention_2":
+            for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
+                if input_name in inputs_dict:
+                    inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])
+
         # make sure that all models have enough positions for generation
         if hasattr(config, "max_position_embeddings"):
             config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
@@ -2339,8 +2347,6 @@ class GenerationTesterMixin:
     @slow
     def test_eager_matches_fa2_generate(self):
         """Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
-        # TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
-        # check whether we still need the overwrites
         self._test_attention_implementation("flash_attention_2")

     def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
@@ -3974,7 +3980,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         # TODO: We need to raise a warning in case the cache is not set correctly
         # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
         #     past_key_values = StaticCache(
-        #         config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
+        #         config=model.config, max_batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
         #     )
         #     results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

@@ -3982,7 +3988,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         layer_device_map = {0: 0, 1: 1}
         past_key_values = StaticCache(
             config=model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=30,
             device=torch_device,
             dtype=model.dtype,
@@ -4183,7 +4189,11 @@ class GenerationIntegrationTests(unittest.TestCase):
         batch_size = 2
         query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
         static_cache = StaticCache(
-            config=config, batch_size=batch_size, max_cache_len=max_cache_len, device=torch_device, dtype=torch.float32
+            config=config,
+            max_batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            device=torch_device,
+            dtype=torch.float32,
         )
         static_cache = model(init_input_ids, past_key_values=static_cache).past_key_values
         model_inputs = model.prepare_inputs_for_generation(
diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
index f452b028a20..be37d3cd175 100644
--- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
+++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
@@ -21,6 +21,7 @@ from parameterized import parameterized

 from transformers import AutoTokenizer, DeepseekV3Config, is_torch_available, set_seed
 from transformers.testing_utils import (
+    cleanup,
     require_read_token,
     require_torch,
     require_torch_accelerator,
@@ -605,6 +606,10 @@ class DeepseekV3IntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_torch_accelerator
     @require_read_token
diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py
index aba1844e343..08bd0295539 100644
--- a/tests/models/diffllama/test_modeling_diffllama.py
+++ b/tests/models/diffllama/test_modeling_diffllama.py
@@ -25,6 +25,7 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, DiffLlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
+    cleanup,
     require_bitsandbytes,
     require_flash_attn,
     require_read_token,
@@ -685,6 +686,10 @@ class DiffLlamaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_torch_accelerator
     @require_read_token
@@ -884,7 +889,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
@@ -932,7 +937,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 20247bd68c0..84802395b5c 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -23,6 +23,7 @@ from packaging import version
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
+    cleanup,
     is_flaky,
     require_bitsandbytes,
     require_flash_attn,
@@ -498,6 +499,10 @@ class GemmaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @require_read_token
     def test_model_2b_fp16(self):
         model_id = "google/gemma-2b"
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index b072105a3fa..7f7930fd8af 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -549,6 +549,13 @@ class LlamaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

+    def tearDown(self):
+        # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
+        # some memory allocated in the cache, which means some object is not being released properly. This causes some
+        # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
+        # Investigate the root cause.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_read_token
     def test_llama_3_1_hard(self):
@@ -748,14 +755,6 @@ class LlamaIntegrationTest(unittest.TestCase):
                 "Simply put, the theory of relativity states that 1) the speed of light is the same for all "
                 "observers, regardless of their location, and 2) the laws of physics are the same for all observers"
             ],
-            "meta-llama/Llama-3.2-3B": [
-                "Simply put, the theory of relativity states that 1. the speed of light is constant, and 2. "
-                "the speed of light is the fastest speed possible"
-            ],
-            "meta-llama/Llama-2-7b-hf": [
-                "Simply put, the theory of relativity states that 1) the speed of light is a constant, and 2) "
-                "the laws of physics are the same for all",
-            ],
         }

         for llama_model_ckp, EXPECTED_TEXT_COMPLETION in llama_models.items():
@@ -946,7 +945,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
@@ -994,7 +993,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py
index 2edf52db3aa..666da7d2b6b 100644
--- a/tests/models/phi3/test_modeling_phi3.py
+++ b/tests/models/phi3/test_modeling_phi3.py
@@ -53,7 +53,7 @@ if is_torch_available():
             self.model = model
             self.cache = StaticCache(
                 config=model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=max_seq_len,
                 device=self.model.device,
                 dtype=self.model.dtype,
diff --git a/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py
index 737e712a34f..b150116818c 100644
--- a/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py
+++ b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py
@@ -227,10 +227,6 @@ class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         pass

-    @unittest.skip(reason="This one tries to use right padding as well")
-    def test_eager_matches_fa2_generate(self):
-        pass
-
     @unittest.skip(reason="Depending on input modalities, some params may not have gradients")
     def test_training_gradient_checkpointing(self):
         pass
diff --git a/tests/models/phimoe/test_modeling_phimoe.py b/tests/models/phimoe/test_modeling_phimoe.py
index b3dc1eba682..0277335c830 100644
--- a/tests/models/phimoe/test_modeling_phimoe.py
+++ b/tests/models/phimoe/test_modeling_phimoe.py
@@ -52,7 +52,7 @@ if is_torch_available():
             self.model = model
             self.cache = StaticCache(
                 config=model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=max_seq_len,
                 device=self.model.device,
                 dtype=self.model.dtype,
diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index 48fe5e8942a..ce543e60661 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -24,6 +24,7 @@ from transformers import T5Config, is_torch_available
 from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
+    cleanup,
     require_accelerate,
     require_sentencepiece,
     require_tokenizers,
@@ -1170,6 +1171,10 @@ class T5ModelFp16Tests(unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class T5ModelIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @cached_property
     def model(self):
         return T5ForConditionalGeneration.from_pretrained("google-t5/t5-base").to(torch_device)
diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py
index 8195d975711..d5d45f43cc3 100644
--- a/tests/quantization/aqlm_integration/test_aqlm.py
+++ b/tests/quantization/aqlm_integration/test_aqlm.py
@@ -226,7 +226,7 @@ class AqlmTest(unittest.TestCase):
         # Setup static KV cache for generation
         past_key_values = StaticCache(
             config=self.quantized_model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=seq_length + self.max_new_tokens + 1,
             device=torch_device,
             dtype=self.quantized_model.config._pre_quantization_dtype,
diff --git a/tests/quantization/spqr_integration/test_spqr.py b/tests/quantization/spqr_integration/test_spqr.py
index 134e57af5de..425cce664c0 100644
--- a/tests/quantization/spqr_integration/test_spqr.py
+++ b/tests/quantization/spqr_integration/test_spqr.py
@@ -207,7 +207,7 @@ class SpQRTest(unittest.TestCase):
         # Setup static KV cache for generation
         past_key_values = StaticCache(
             config=self.quantized_model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=seq_length + self.max_new_tokens + 1,
             device=torch_device,
             dtype=self.quantized_model.config._pre_quantization_dtype,
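
For reference, a minimal sketch of the renamed StaticCache argument in use, mirroring the docs/source/en/llm_optims.md snippet updated above; the checkpoint name, prompt, and token counts are placeholders only.

    from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

    # Placeholder checkpoint; any causal LM that supports StaticCache works the same way.
    model_id = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Simply put, the theory of relativity states that", return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[1]

    # `max_batch_size` replaces the former `batch_size` argument; everything else is unchanged.
    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=prompt_length + 32,
        device=model.device,
        dtype=model.dtype,
    )
    outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=16)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))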