mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
* fix: RecurrentGemma crashes during inference for inputs longer than sliding window width * fix recurrentgemma tests; add long test bigger than context window
This commit is contained in:
parent
964a1b6b7d
commit
413f9bbf80
@ -767,6 +767,8 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
# Crop the attention mask to the target length.
|
||||
attention_mask = attention_mask[:, -target_length:]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
|
@ -278,11 +278,12 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
@slow
|
||||
class RecurrentGemmaIntegrationTest(unittest.TestCase):
|
||||
input_text = ["Hello I am doing", "Hi today"]
|
||||
input_long_text = ['<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'] # fmt: skip
|
||||
model_id = "google/recurrentgemma-2b"
|
||||
|
||||
@require_read_token
|
||||
def test_2b_generate(self):
|
||||
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a very good day for you. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do.'] # fmt: skip
|
||||
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a new app that allows you to make money by watching videos.\n\nThe app is very simple to use and you can earn money by watching videos.\n\nThe app is available for both Android and iOS devices and you can download it from the Google Play Store or the App Store.\n\nOnce you have downloaded the app'] # fmt: skip
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, low_cpu_mem_usage=True).to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
@ -296,7 +297,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I am going to share with you the best <strong><em>free online video editing software</em></strong>.\n\n<h2><strong>Best Free Online Video Editing Software</strong></h2>\n\n<strong>1.</strong> <strong>Wondershare Filmora</strong>\n\nWondershare Filmora is a free online video editing software that is used to edit videos.'] # fmt: skip
|
||||
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I’m going to show you how to make a simple and easy to make a <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY'] # fmt: skip
|
||||
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
|
||||
@ -316,7 +317,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
|
||||
@require_read_token
|
||||
def test_2b_sample(self):
|
||||
set_seed(0)
|
||||
EXPECTED_TEXT = ['Where is Paris ?\n\nAnswer this question "yes" or "no": Could a person pass out in subzero temperatures?\n\nFor the sentence below, underline the pronoun in parentheses that agrees with its antecedent.\n\nExample 1. Mary and Pam will have the opportunity to prove (herself, $\\underline{\\text{themselves}}$) at the concert.\n\nThe waiters and the manager at the restaurant will do <em>(his, their)</em> best to assist you.\n\nA vocabulary word appears in italics in the short passage below. Think about how the word is used. Then write a definition for the vocabulary word.\n\nAfter a one-hour $'] # fmt: skip
|
||||
EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write <em>P</em> if the underlined word group is a phrase and <em>NP</em> if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
@ -329,13 +330,13 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_2b_8bit(self):
|
||||
EXPECTED_TEXTS = ['<bos>Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "<bos>Hi today<pad><pad> I'm going to show you how to make a simple and easy to use <strong><em><u>"] # fmt: skip
|
||||
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@ -345,18 +346,28 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_read_token
|
||||
def test_long_context(self):
|
||||
input_text = [
|
||||
'<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'
|
||||
]
|
||||
EXPECTED_GENERATION = [
|
||||
' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." "We are not aware of any video footage that could have been taken on board the plane," Delannoy said. "We are not aware of any video footage that could'
|
||||
]
|
||||
EXPECTED_GENERATION = [' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." He added that the BEA is "not aware of any video footage that could have been taken on board the plane." The BEA is the French equivalent of the National Transportation Safety Board'] # fmt: skip
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
|
||||
inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
|
||||
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
|
||||
print(output_text)
|
||||
self.assertEqual(output_text, EXPECTED_GENERATION)
|
||||
|
||||
@require_read_token
|
||||
def test_longer_than_window(self):
|
||||
EXPECTED_GENERATION = [" Robin's comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the"] # fmt: skip
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
model.config.attention_window_size = 256 # Make the attention window size shorter than the current prompt
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
|
||||
inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
|
||||
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
|
||||
self.assertEqual(output_text, EXPECTED_GENERATION)
|
||||
|
Loading…
Reference in New Issue
Block a user