From b75ad56620431984a44a962c98136c8571b4fca9 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 31 Jul 2024 11:12:46 +0100
Subject: [PATCH] Llama 3.1: Fix incorrect `inv_freq` assignment (#32330)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

fix 💩
---
 src/transformers/modeling_rope_utils.py   |  8 +++---
 tests/models/llama/test_modeling_llama.py | 30 ++++++++++++++++++++++-
 2 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index 61315e18d41..66a0afcef6b 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -328,14 +328,14 @@ def _compute_llama3_parameters(
     wavelen = 2 * math.pi / inv_freq
     # wavelen < high_freq_wavelen: do nothing
     # wavelen > low_freq_wavelen: divide by factor
-    inv_freq_new = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
     # otherwise: interpolate between the two, using a smooth factor
     smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_new / factor + smooth_factor * inv_freq_new
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
     is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
-    inv_freq_new = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_new)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
 
-    return inv_freq, attention_factor
+    return inv_freq_llama, attention_factor
 
 
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 19cb7bd6be3..de17c088ed7 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -22,7 +22,7 @@
 import pytest
 from packaging import version
 from parameterized import parameterized
-from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
+from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
     require_flash_attn,
@@ -718,6 +718,34 @@ class LlamaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    @slow
+    @require_read_token
+    def test_llama_3_1_hard(self):
+        """
+        An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences
+        from llama 3.1.'s RoPE can be detected
+        """
+        EXPECTED_TEXT = (
+            "Tell me about the french revolution. The french revolution was a period of radical social and political "
+            "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
+            "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
+            "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative "
+            "assembly that had not met since 1614. The Third Estate, which represented the common people, "
+            "demanded greater representation and eventually broke away to form the National Assembly. This marked "
+            "the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+        model = LlamaForCausalLM.from_pretrained(
+            "meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
+        )
+        input_text = ["Tell me about the french revolution."]
+        model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
+        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(generated_text, EXPECTED_TEXT)
+
     @slow
     @require_read_token
     def test_model_7b_logits_bf16(self):
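
Note (not part of the patch above): the `modeling_rope_utils.py` hunk renames `inv_freq_new` to `inv_freq_llama` and, more importantly, returns the scaled tensor instead of the untouched `inv_freq`, which the old code silently discarded. The self-contained sketch below replays the same llama3 scaling math outside of transformers so the effect of the fix is easy to inspect. The constants mirror the published Llama 3.1 8B defaults (rope_theta=500000, head dim 128, factor 8.0, low/high frequency factors 1.0/4.0, original context 8192) but are assumptions for illustration, not values taken from this diff.

import math

import torch

# Illustrative Llama 3.1-style settings (assumed, see note above).
base = 500_000.0        # rope_theta
dim = 128               # attention head dimension
factor = 8.0            # rope_scaling["factor"]
low_freq_factor = 1.0   # rope_scaling["low_freq_factor"]
high_freq_factor = 4.0  # rope_scaling["high_freq_factor"]
old_context_len = 8192  # rope_scaling["original_max_position_embeddings"]

# Default (unscaled) RoPE inverse frequencies.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq

# Same three-branch logic as the patched function: high frequencies are kept,
# low frequencies are divided by `factor`, medium frequencies are smoothly interpolated.
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

# Before this patch the function returned `inv_freq`, so the scaling above had no effect.
print(torch.allclose(inv_freq, inv_freq_llama))  # False: the scaled frequencies differ
print(f"{(inv_freq_llama != inv_freq).sum().item()} of {inv_freq.numel()} frequencies rescaled")

The new `test_llama_3_1_hard` integration test generates 128 greedy tokens, in line with its docstring: only a sufficiently long output exposes the subtle numerical differences introduced by Llama 3.1's RoPE scaling.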