[cache] add a test to confirm we can use cache at train time (#35709)

* add test

* augment test as suggested

* Update tests/utils/test_modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* rerun tests

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Joao Gante 2025-01-16 17:02:34 +00:00 committed by GitHub
parent 57bf1a12a0
commit aeeceb9916
@@ -37,6 +37,7 @@ from transformers import (
    AutoModel,
    AutoModelForImageClassification,
    AutoModelForSequenceClassification,
    DynamicCache,
    LlavaForConditionalGeneration,
    OwlViTForObjectDetection,
    PretrainedConfig,
@@ -1790,6 +1791,43 @@ class ModelUtilsTest(TestCasePlus):
        )
        self.assertTrue(check_models_equal(model, model_loaded))

    def test_cache_when_needed_at_train_time(self):
        """
        Some fine-tuning methods require the use of a cache, like prefix tuning in PEFT. This test checks that a
        cache is used at train time if we request it. Related issue: #35648
        """
        model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
        tokenizer = AutoTokenizer.from_pretrained(TINY_MISTRAL)
        model_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

        # By default the model is not in training mode; we have to set it explicitly
        self.assertFalse(model.training)
        model.train()

        # If we set `use_cache=True` while training, then a cache is returned
        model_outputs = model(**model_inputs, use_cache=True)
        self.assertIsInstance(model_outputs.past_key_values, DynamicCache)
        self.assertTrue(model.training)

        # Simulate injecting virtual tokens, as prefix tuning does: one tensor per layer that unpacks
        # into a (key, value) pair, plus an attention mask widened to cover the virtual tokens
        num_virtual_tokens = 3
        past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        model_inputs["attention_mask"] = torch.cat(
            (
                model_inputs["attention_mask"],
                torch.ones(1, num_virtual_tokens).to(model_inputs["attention_mask"].device),
            ),
            dim=1,
        )
        model_outputs = model(**model_inputs, past_key_values=past_key_values, use_cache=True)
        self.assertTrue(model.training)

        # We can also disable the cache to skip a few operations, if the training loop doesn't need one
        model_outputs = model(**model_inputs, use_cache=False)
        self.assertIsNone(model_outputs.past_key_values)
        self.assertTrue(model.training)
    @slow
    @require_torch
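
The virtual-token block in the test mirrors what prefix tuning does: learned key/value tensors are placed in the cache ahead of the real sequence, and the attention mask is widened so the input tokens can attend to them. A hedged, self-contained sketch of that pattern (the cache shapes, 2 layers with 2 KV heads of head dim 8, are assumptions matching the tiny test checkpoint, not a general contract):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"  # assumed repo id
    model = AutoModelForCausalLM.from_pretrained(checkpoint).train()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    num_virtual_tokens = 3
    # Legacy cache format: one (key, value) pair per layer; a tensor with a leading
    # dimension of 2 unpacks into such a pair when converted to a DynamicCache
    legacy = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
    cache = DynamicCache.from_legacy_cache(legacy)

    # The mask must cover past + current tokens; every position is unmasked here,
    # so the concatenation order does not matter
    attention_mask = torch.cat(
        (inputs["attention_mask"], torch.ones(1, num_virtual_tokens, dtype=inputs["attention_mask"].dtype)),
        dim=1,
    )
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=attention_mask,
        past_key_values=cache,
        use_cache=True,
    )
    # In a real PEFT prefix-tuning setup the key/value tensors would be learned
    # parameters rather than random noise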