mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[cache] add a test to confirm we can use cache at train time (#35709)
* add test * augment test as suggested * Update tests/utils/test_modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * rerun tests --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
57bf1a12a0
commit
aeeceb9916
@ -37,6 +37,7 @@ from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForSequenceClassification,
|
||||
DynamicCache,
|
||||
LlavaForConditionalGeneration,
|
||||
OwlViTForObjectDetection,
|
||||
PretrainedConfig,
|
||||
@ -1790,6 +1791,43 @@ class ModelUtilsTest(TestCasePlus):
|
||||
)
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_cache_when_needed_at_train_time(self):
|
||||
"""
|
||||
Some fine-tuning methods require the use of cache, like prefix tuning in PEFT. This test checks that a cache
|
||||
is at train time used if we request it. Related issue: #35648
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_MISTRAL)
|
||||
model_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
|
||||
# By default it is not training, we have to set it
|
||||
self.assertFalse(model.training)
|
||||
model.train()
|
||||
|
||||
# If we set `use_cache=True` while training, then a cache is returned
|
||||
model_outputs = model(**model_inputs, use_cache=True)
|
||||
self.assertIsInstance(model_outputs.past_key_values, DynamicCache)
|
||||
self.assertTrue(model.training)
|
||||
|
||||
# simulate injecting virtual tokens like in prefix tuning
|
||||
num_virtual_tokens = 3
|
||||
past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
model_inputs["attention_mask"] = torch.cat(
|
||||
(
|
||||
model_inputs["attention_mask"],
|
||||
torch.ones(1, num_virtual_tokens).to(model_inputs["attention_mask"].device),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
model_outputs = model(**model_inputs, past_key_values=past_key_values, use_cache=True)
|
||||
self.assertTrue(model.training)
|
||||
|
||||
# We can also disable the cache to skip a few operations, if the training loop doesn't need cache
|
||||
model_outputs = model(**model_inputs, use_cache=False)
|
||||
self.assertIsNone(model_outputs.past_key_values)
|
||||
self.assertTrue(model.training)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
Loading…
Reference in New Issue
Block a user