[InstructBlip] Add instruct blip int8 test (#24555)

* add 8bit instructblip test

* update tests
Younes Belkada 2023-06-28 19:06:30 +02:00 committed by GitHub
parent c70c88a268
commit 33b5ef5cdf
3 changed files with 13 additions and 24 deletions
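
For context, the usage pattern the new integration test exercises looks roughly like the following. This is a minimal sketch, not part of the commit; it assumes a CUDA GPU with the bitsandbytes and accelerate packages installed.

import requests
import torch
from PIL import Image
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
# load_in_8bit quantizes the weights with bitsandbytes and dispatches them to the GPU,
# instead of materializing the full-precision model and calling .to(device)
model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-vicuna-7b", load_in_8bit=True
)

url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
prompt = "What is unusual about this image?"

# cast pixel_values to fp16 to match the compute dtype used by the 8-bit model
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda", torch.float16)
outputs = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(outputs, skip_special_tokens=True)[0].strip())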

src/transformers/models/instructblip/modeling_instructblip.py

@@ -1030,7 +1030,7 @@ class InstructBlipQFormerEmbeddings(nn.Module):
         if input_ids is not None:
             embeddings = self.word_embeddings(input_ids)
             if self.position_embedding_type == "absolute":
-                position_embeddings = self.position_embeddings(position_ids)
+                position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
                 embeddings = embeddings + position_embeddings

             if query_embeds is not None:
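
The device cast above is presumably what lets the Q-Former run when the model is loaded with load_in_8bit=True and dispatched by accelerate, where the position_ids buffer and the embedding weights can end up on different devices. A standalone sketch of the failure mode it avoids (hypothetical module sizes, not part of the commit):

import torch
import torch.nn as nn

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# stand-ins for the Q-Former embedding tables after dispatch
word_embeddings = nn.Embedding(30522, 768).to(device)
position_embeddings = nn.Embedding(512, 768).to(device)

position_ids = torch.arange(4).unsqueeze(0)                 # buffer created on CPU
input_ids = torch.randint(0, 30522, (1, 4), device=device)

embeddings = word_embeddings(input_ids)
# without the cast, position_embeddings(position_ids) raises a device-mismatch
# error whenever the weights live on the GPU but the buffer stayed on CPU
position_embeds = position_embeddings(position_ids.to(embeddings.device))
embeddings = embeddings + position_embeds
print(embeddings.shape)  # torch.Size([1, 4, 768])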

src/transformers/models/instructblip/processing_instructblip.py

@@ -19,9 +19,10 @@ Processor class for InstructBLIP. Largely copy of Blip2Processor with addition o
 import os
 from typing import List, Optional, Union

+from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from ...utils import TensorType

 from ..auto import AutoTokenizer
@@ -71,7 +72,7 @@ class InstructBlipProcessor(ProcessorMixin):
         verbose: bool = True,
         return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
-    ) -> BatchEncoding:
+    ) -> BatchFeature:
        """
        This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.
@@ -81,7 +82,7 @@ class InstructBlipProcessor(ProcessorMixin):
         if images is None and text is None:
             raise ValueError("You have to specify at least images or text.")

-        encoding = BatchEncoding()
+        encoding = BatchFeature()

         if text is not None:
             text_encoding = self.tokenizer(
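
The switch from BatchEncoding to BatchFeature matters for the 8-bit test below: BatchEncoding.to() only accepts a device, while BatchFeature.to() also accepts a dtype and casts only the floating-point tensors, which is what the test's .to(torch_device, torch.float16) call relies on. A minimal sketch of that behavior (not part of the commit; assumes a CUDA device):

import numpy as np
import torch
from PIL import Image
from transformers import InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))  # dummy image, just for shapes

inputs = processor(images=image, text="describe the image", return_tensors="pt")
inputs = inputs.to("cuda", torch.float16)    # device and dtype in a single call
print(inputs["pixel_values"].dtype)  # torch.float16 -- floating-point tensors are cast
print(inputs["input_ids"].dtype)     # torch.int64  -- integer ids are moved but not cast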

tests/models/instructblip/test_modeling_instructblip.py

@@ -521,51 +521,39 @@ def prepare_img():
 @require_torch
 @slow
 class InstructBlipModelIntegrationTest(unittest.TestCase):
-    # TODO (@Younes): Re-enable this when 8-bit or 4-bit is implemented.
-    @unittest.skip(reason="GPU OOM")
     def test_inference_vicuna_7b(self):
         processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
-        model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b").to(
-            torch_device
+        model = InstructBlipForConditionalGeneration.from_pretrained(
+            "Salesforce/instructblip-vicuna-7b", load_in_8bit=True
         )

         url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
         image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
         prompt = "What is unusual about this image?"
-        inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
+        inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)

         # verify logits
         with torch.no_grad():
             logits = model(**inputs).logits

         expected_slice = torch.tensor(
-            [[-3.4684, -12.6759, 8.5067], [-5.1305, -12.2058, 7.9834], [-4.0632, -13.9285, 9.2327]],
+            [[-3.5410, -12.2812, 8.2812], [-5.2500, -12.0938, 7.8398], [-4.1523, -13.8281, 9.0000]],
             device=torch_device,
         )
-        assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-5)
+        self.assertTrue(torch.allclose(logits[0, :3, :3].float(), expected_slice, atol=1e-3))

         # verify generation
-        outputs = model.generate(
-            **inputs,
-            do_sample=False,
-            num_beams=5,
-            max_length=256,
-            min_length=1,
-            top_p=0.9,
-            repetition_penalty=1.5,
-            length_penalty=1.0,
-            temperature=1,
-        )
+        outputs = model.generate(**inputs, max_new_tokens=30)
+        outputs[outputs == 0] = 2
         generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()

         # fmt: off
-        expected_outputs = [2, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 29892, 607, 338, 14089, 287, 297, 278, 7256, 310, 263, 19587, 4272, 11952, 29889, 910, 338, 385, 443, 535, 794, 1848, 2948, 304, 13977, 292, 22095, 29892, 408, 372, 6858, 278, 767, 304, 17346, 3654, 322, 670, 13977, 292, 21083, 373, 2246, 310, 278, 19716, 1550, 12402, 1218, 1549, 12469, 29889, 19814, 29892, 278, 10122, 310, 8818, 275, 322, 916, 24413, 297, 278, 9088, 4340, 19310, 7093, 278, 22910, 5469, 310, 445, 6434, 29889, 2, 1]
+        expected_outputs = [ 2, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 1623, 263, 19587, 4272, 11952, 29889]
         # fmt: on
         self.assertEqual(outputs[0].tolist(), expected_outputs)
         self.assertEqual(
             generated_text,
-            "The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.",
+            "The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving down a busy city street.",
         )

     def test_inference_flant5_xl(self):