diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 62b571b2928..f483f79cd51 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -55,15 +55,14 @@ class DbrxRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 8477dc2b5df..d511e50dd9f 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -104,15 +104,14 @@ class GemmaRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index f8dc2c96846..8b04575cfb6 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -397,15 +397,14 @@ class JetMoeRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
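The same two-part change repeats for every rotary-embedding class in this patch (DBRX, Gemma, JetMoE, Phi-3, RecurrentGemma): `inv_freq` is computed eagerly in `__init__` and registered as a non-persistent buffer, instead of being registered as `None` and materialized on the first `forward` call. One quirk worth knowing: the added `self.inv_freq.to(x.device)` line is a no-op as written, because `Tensor.to` returns a new tensor rather than moving one in place; in practice this is harmless, since buffers already travel with the module on `model.to(device)`. Below is a minimal, self-contained sketch of the pattern; the class name and the simplified `forward` are illustrative, not the exact transformers code.

```python
import torch
from torch import nn


class RotaryEmbeddingSketch(nn.Module):
    """Minimal sketch of the eager-init pattern (illustrative, not the transformers class)."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        # Eager init: the buffer exists with a concrete shape as soon as the module
        # is constructed, so a compiled graph never sees a "create on first forward" branch.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Buffers move with the module, so no per-call device fix-up is required.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # kept in float32 on purpose
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


# Usage: cos/sin tables for a batch of position ids.
rope = RotaryEmbeddingSketch(dim=64)
x = torch.randn(2, 8, 10, 64)  # [bs, heads, seq, head_dim]
cos, sin = rope(x, torch.arange(10)[None, :].expand(2, -1))
```

Eager construction is what `torch.compile(fullgraph=True)` needs: the buffer has a concrete shape at trace time and `forward` contains no first-call-only branch.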
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index db88d607a36..71fcddb4644 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -99,15 +99,14 @@ class Phi3RotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index a115dc687c5..06aef9a03d8 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -68,16 +68,14 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         super().__init__()
         self.dim = dim
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

     @torch.no_grad()
     # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
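For context on why the lazy path had to go: assigning `self.inv_freq` inside `forward` introduces a branch that only the first call takes, and it mutates module state mid-trace, neither of which `torch.compile(fullgraph=True)` can fold into a single graph. The snippet below is a toy reproduction of that failure mode, not code from the PR; whether it hard-errors or merely falls back with a graph break depends on the torch version, so treat it as illustrative.

```python
import torch
from torch import nn


class LazyBuffer(nn.Module):
    """Old pattern: buffer registered as None, materialized on first forward."""

    def __init__(self, dim=8):
        super().__init__()
        self.dim = dim
        self.register_buffer("inv_freq", None, persistent=False)

    def forward(self, x):
        if self.inv_freq is None:  # first-call-only branch + attribute mutation mid-trace
            self.inv_freq = 1.0 / (
                10000.0 ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim)
            )
        return x[..., : self.dim // 2] * self.inv_freq


compiled = torch.compile(LazyBuffer(), fullgraph=True)
try:
    compiled(torch.randn(2, 8))
except Exception as err:  # expected under fullgraph=True on affected torch versions
    print(f"{type(err).__name__}: {err}")
```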
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 80f275e54ce..7cc2ebe7bd2 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -17,6 +17,7 @@ import tempfile
 import unittest

 import pytest
+from packaging import version

 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -40,7 +41,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch

-    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel
+    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer


 class GemmaModelTester:
@@ -302,6 +303,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.6]

+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "google/gemma-2b"
+
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
@@ -801,3 +805,51 @@ class GemmaIntegrationTest(unittest.TestCase):
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

         self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_compile_static_cache(self):
+        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
+        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
+        if version.parse(torch.__version__) < version.parse("2.3.0"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        NUM_TOKENS_TO_GENERATE = 40
+        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
+        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
+        EXPECTED_TEXT_COMPLETION = {
+            8: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+            7: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+        }
+
+        prompts = ["Hello I am doing", "Hi today"]
+        tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
+        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+
+        # Dynamic Cache
+        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
+        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)  # Both GPU architectures have the same output
+
+        # Static Cache
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+
+        # Static Cache + compile
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
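The new integration test runs the same greedy generation three ways (dynamic cache, static cache, and static cache plus a compiled `forward`) and expects identical text from all three. The keys of `EXPECTED_TEXT_COMPLETION` are CUDA compute-capability major versions: 8 covers Ampere-class CI runners, 7 covers Volta/Turing-class ones. A condensed sketch of the progression follows, assuming the Auto classes, a CUDA device, and a tokenizer with a pad token; it is a distillation of the test, not a drop-in replacement for it.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "google/gemma-2b"  # the checkpoint pinned by the test
tokenizer = AutoTokenizer.from_pretrained(ckpt, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer(["Hello I am doing", "Hi today"], return_tensors="pt", padding=True).to(model.device)

# 1) Dynamic cache (default): the KV cache grows step by step, so tensor shapes change.
dynamic_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)

# 2) Static cache: pre-allocated to a fixed length, so every decode step sees
#    identical shapes, the prerequisite for shape-stable compilation.
static_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# 3) Compiled forward on top of the static cache: "reduce-overhead" enables CUDA
#    graphs, and fullgraph=True asserts the model traces without graph breaks.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
compiled_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# Greedy decoding should agree across all three paths.
texts = [tokenizer.batch_decode(ids, skip_special_tokens=True) for ids in (dynamic_ids, static_ids, compiled_ids)]
assert texts[0] == texts[1] == texts[2]
```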
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 5d402bd8599..0e563b114b0 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -312,6 +312,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]

+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
+
     def setUp(self):
         self.model_tester = LlamaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index b2f96c72698..533eedb7f33 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -27,6 +27,7 @@ from collections import defaultdict
 from typing import Dict, List, Tuple

 import numpy as np
+from packaging import version
 from parameterized import parameterized
 from pytest import mark

@@ -35,6 +36,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
+    AutoTokenizer,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -71,6 +73,7 @@ from transformers.testing_utils import (
     require_accelerate,
     require_bitsandbytes,
     require_flash_attn,
+    require_read_token,
     require_safetensors,
     require_torch,
     require_torch_gpu,
@@ -4399,6 +4402,38 @@ class ModelTesterMixin:
         normalized_1 = F.softmax(out_shared_prefix_last_tokens)
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

+    # For now, Let's focus only on GPU for `torch.compile`
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_torch_compile(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_test_ckpt"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
+        ckpt = self._torch_compile_test_ckpt
+
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        batch_size = 1
+        n_iter = 3
+
+        tokenizer = AutoTokenizer.from_pretrained(ckpt)
+        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
+
+        model.generation_config.max_new_tokens = 4
+        model.generation_config.max_new_tokens = 4
+
+        model.generation_config.cache_implementation = "static"
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        input_text = "Why dogs are cute?"
+        input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
+
+        for i in range(n_iter):
+            _ = model.generate(**input_ids, do_sample=False)


 global_rng = random.Random()
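The shared `test_torch_compile` is opt-in: a model's test class advertises a checkpoint through `_torch_compile_test_ckpt`, and the test skips itself when the attribute is missing or torch is older than 2.3. (Note in passing that `model.generation_config.max_new_tokens = 4` is assigned twice in the hunk above; the second assignment is redundant.) Opting another model in looks roughly like the sketch below; the class, model import, and checkpoint name are hypothetical placeholders, not part of this PR.

```python
import unittest

from transformers import is_torch_available

# Inside the transformers test tree this would be the usual relative import:
# from ...test_modeling_common import ModelTesterMixin
from tests.test_modeling_common import ModelTesterMixin

if is_torch_available():
    from my_pkg.modeling_my_model import MyModelForCausalLM  # hypothetical model class


class MyModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (MyModelForCausalLM,) if is_torch_available() else ()

    # Picked up by `ModelTesterMixin.test_torch_compile`; the shared test skips
    # itself when this attribute is absent or when torch < 2.3.
    _torch_compile_test_ckpt = "my-org/my-model"  # hypothetical checkpoint
```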