Fixes #37219: RecurrentGemma crashes for inputs longer than sliding window length (#37613)

* fix: RecurrentGemma crashes during inference for inputs longer than sliding window width

* fix RecurrentGemma tests; add a long-input test bigger than the context window
Manuel de Prada Corral 2025-04-22 12:21:16 +02:00 committed by GitHub
parent 964a1b6b7d
commit 413f9bbf80
2 changed files with 26 additions and 13 deletions


@@ -767,6 +767,8 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         if attention_mask is not None:
             causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
+                # Crop the attention mask to the target length.
+                attention_mask = attention_mask[:, -target_length:]
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                 causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
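For context on the two added lines above: `causal_mask` is built with its key dimension capped at `target_length`, which RecurrentGemma bounds by the sliding attention window, while a 2D `attention_mask` still spans the full prompt. Once the prompt is longer than the window, the two tensors no longer broadcast and the padding-mask computation fails with a size mismatch. The snippet below is a minimal standalone sketch of that interaction, not the transformers implementation; the `apply_padding_mask` helper and the toy shapes are illustrative assumptions.

import torch

def apply_padding_mask(causal_mask: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative helper (not the transformers API): fold a 2D padding mask into a
    causal mask whose key dimension is already capped at the sliding-window size."""
    min_dtype = torch.finfo(causal_mask.dtype).min
    target_length = causal_mask.shape[-1]
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    # The fix: keep only the last `target_length` key positions of the padding mask.
    # Without this slice, a prompt longer than the window makes the elementwise
    # product below fail with a size mismatch (4 vs. 10 in the toy shapes used here).
    attention_mask = attention_mask[:, -target_length:]
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
    causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask

# Toy shapes: a 4-position sliding window and a 10-token prompt (longer than the window).
batch, q_len, window, prompt_len = 1, 10, 4, 10
causal = torch.zeros(batch, 1, q_len, window)  # 0.0 = attend, min_dtype = blocked
padding = torch.ones(batch, prompt_len)        # 1 = real token, 0 = padding
out = apply_padding_mask(causal, padding)
print(out.shape)  # torch.Size([1, 1, 10, 4])

Dropping the `[:, -target_length:]` slice from the sketch reproduces the kind of shape error reported in #37219 for prompts longer than the attention window; the new `test_longer_than_window` test below exercises that path by shrinking `attention_window_size` below the prompt length.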


@@ -278,11 +278,12 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
 @slow
 class RecurrentGemmaIntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
+    input_long_text = ['<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'] # fmt: skip
     model_id = "google/recurrentgemma-2b"

     @require_read_token
     def test_2b_generate(self):
-        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a very good day for you. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do.'] # fmt: skip
+        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a new app that allows you to make money by watching videos.\n\nThe app is very simple to use and you can earn money by watching videos.\n\nThe app is available for both Android and iOS devices and you can download it from the Google Play Store or the App Store.\n\nOnce you have downloaded the app'] # fmt: skip
         model = AutoModelForCausalLM.from_pretrained(self.model_id, low_cpu_mem_usage=True).to(torch_device)
         tokenizer = AutoTokenizer.from_pretrained(self.model_id)
@@ -296,7 +297,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)

         tokenizer.padding_side = "left"
-        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I am going to share with you the best <strong><em>free online video editing software</em></strong>.\n\n<h2><strong>Best Free Online Video Editing Software</strong></h2>\n\n<strong>1.</strong> <strong>Wondershare Filmora</strong>\n\nWondershare Filmora is a free online video editing software that is used to edit videos.'] # fmt: skip
+        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today Im going to show you how to make a simple and easy to make a <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY'] # fmt: skip
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
         output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
@@ -316,7 +317,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
     @require_read_token
     def test_2b_sample(self):
         set_seed(0)
-        EXPECTED_TEXT = ['Where is Paris ?\n\nAnswer this question "yes" or "no": Could a person pass out in subzero temperatures?\n\nFor the sentence below, underline the pronoun in parentheses that agrees with its antecedent.\n\nExample 1. Mary and Pam will have the opportunity to prove (herself, $\\underline{\\text{themselves}}$) at the concert.\n\nThe waiters and the manager at the restaurant will do <em>(his, their)</em> best to assist you.\n\nA vocabulary word appears in italics in the short passage below. Think about how the word is used. Then write a definition for the vocabulary word.\n\nAfter a one-hour $'] # fmt: skip
+        EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write <em>P</em> if the underlined word group is a phrase and <em>NP</em> if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip
         model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device)
         tokenizer = AutoTokenizer.from_pretrained(self.model_id)
@@ -329,13 +330,13 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     @require_read_token
     def test_model_2b_8bit(self):
-        EXPECTED_TEXTS = ['<bos>Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "<bos>Hi today<pad><pad> I'm going to show you how to make a simple and easy to use <strong><em><u>"] # fmt: skip
+        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip
         model = AutoModelForCausalLM.from_pretrained(
             "gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -345,18 +346,28 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
     @require_read_token
     def test_long_context(self):
-        input_text = [
-            '<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'
-        ]
-        EXPECTED_GENERATION = [
-            ' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." "We are not aware of any video footage that could have been taken on board the plane," Delannoy said. "We are not aware of any video footage that could'
-        ]
+        EXPECTED_GENERATION = [' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." He added that the BEA is "not aware of any video footage that could have been taken on board the plane." The BEA is the French equivalent of the National Transportation Safety Board'] # fmt: skip
         model = AutoModelForCausalLM.from_pretrained(
             self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
         ).to(torch_device)
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
+        inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
         output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
         output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
         print(output_text)
         self.assertEqual(output_text, EXPECTED_GENERATION)

+    @require_read_token
+    def test_longer_than_window(self):
+        EXPECTED_GENERATION = [" Robin's comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the"] # fmt: skip
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
+        ).to(torch_device)
+        model.config.attention_window_size = 256  # Make the attention window size shorter than the current prompt
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
+        inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
+        output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+        output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
+        self.assertEqual(output_text, EXPECTED_GENERATION)