mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[BLIP-2] Improve conversion script (#24854)
* Improve conversion script * Add int8 code example * Update tip * Fix code * Fix code snippet * Add nucleus sampling * More improvements * Address comments * Address comments
This commit is contained in:
parent
17fdd35481
commit
5469c18762
@ -90,7 +90,7 @@ class Blip2VisionConfig(PretrainedConfig):
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=0.00001,
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=1e-10,
|
||||
qkv_bias=True,
|
||||
|
@ -24,7 +24,8 @@ import requests
|
||||
import torch
|
||||
|
||||
# pip3 install salesforce-lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install git+https://github.com/nielsrogge/LAVIS.git@fix_lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
|
||||
# to make sure we can compare both original and HF implementation in float32
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from PIL import Image
|
||||
|
||||
@ -37,6 +38,7 @@ from transformers import (
|
||||
BlipImageProcessor,
|
||||
OPTConfig,
|
||||
T5Config,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
|
||||
@ -145,11 +147,16 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
||||
|
||||
name, type = model_name_to_original[model_name]
|
||||
|
||||
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
||||
# which requires quite some memory. Hence loading both on a
|
||||
# separate device is the easiest to compare
|
||||
hf_model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
lavis_device = "cuda:1" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# load original model
|
||||
print("Loading original model...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||
name=name, model_type=type, is_eval=True, device=device
|
||||
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||
)
|
||||
original_model.eval()
|
||||
print("Done!")
|
||||
@ -185,61 +192,53 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||
|
||||
image = load_demo_image()
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(device)
|
||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(device)
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
# create processor
|
||||
image_processor = BlipImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||
)
|
||||
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
|
||||
|
||||
# make sure processor creates exact same pixel values
|
||||
assert torch.allclose(pixel_values, original_pixel_values)
|
||||
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
|
||||
|
||||
original_model.to(device)
|
||||
hf_model.to(device)
|
||||
original_model.to(lavis_device)
|
||||
hf_model.to(hf_model_device)
|
||||
with torch.no_grad():
|
||||
if "opt" in model_name:
|
||||
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
||||
logits = hf_model(original_pixel_values, input_ids).logits
|
||||
logits = hf_model(pixel_values, input_ids).logits
|
||||
else:
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
|
||||
).logits
|
||||
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
|
||||
logits = hf_model(original_pixel_values, input_ids, labels=labels).logits
|
||||
logits = hf_model(pixel_values, input_ids, labels=labels).logits
|
||||
|
||||
assert original_logits.shape == logits.shape
|
||||
print("First values of original logits:", original_logits[0, :3, :3])
|
||||
print("First values of HF logits:", logits[0, :3, :3])
|
||||
|
||||
# assert values
|
||||
if model_name == "blip2-flan-t5-xl":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-41.5850, -4.4440, -8.9922], [-47.4322, -5.9143, -1.7340]], device=device
|
||||
)
|
||||
assert torch.allclose(logits[0, :3, :3], expected_slice_logits, atol=1e-4)
|
||||
elif model_name == "blip2-flan-t5-xl-coco":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-57.0109, -9.8967, -12.6280], [-68.6578, -12.7191, -10.5065]], device=device
|
||||
)
|
||||
else:
|
||||
# cast to same type
|
||||
target_dtype = logits.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits, atol=1e-2)
|
||||
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
print("Generating a caption...")
|
||||
prompt = ""
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||
prompt = "Question: what object is in this image? Answer:"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
original_outputs = original_model.generate({"image": original_pixel_values})
|
||||
set_seed(42)
|
||||
|
||||
original_outputs = original_model.generate(
|
||||
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True
|
||||
)
|
||||
outputs = hf_model.generate(
|
||||
original_pixel_values,
|
||||
pixel_values,
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
do_sample=True,
|
||||
num_beams=5,
|
||||
max_length=30,
|
||||
min_length=1,
|
||||
@ -248,10 +247,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
||||
length_penalty=1.0,
|
||||
temperature=1,
|
||||
)
|
||||
print("Original generation:", original_outputs)
|
||||
prompt_length = input_ids.shape[1]
|
||||
output_text = processor.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
|
||||
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
print("Original generation:", original_outputs)
|
||||
print("HF generation:", output_text)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
|
@ -1556,6 +1556,12 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
|
||||
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
||||
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that Flan-T5 checkpoints cannot be cast to float16. They are pre-trained using bfloat16.
|
||||
|
||||
</Tip>
|
||||
""",
|
||||
BLIP_2_START_DOCSTRING,
|
||||
)
|
||||
@ -1687,15 +1693,40 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
||||
|
||||
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
>>> model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
... "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||
... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
|
||||
... ) # doctest: +IGNORE_RESULT
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> prompt = "Question: how many cats are there? Answer:"
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
|
||||
|
||||
>>> generated_ids = model.generate(**inputs)
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
>>> print(generated_text)
|
||||
two
|
||||
```
|
||||
|
||||
Note that int8 inference is also supported through [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
|
||||
This greatly reduces the amount of memory used by the model while maintaining the same performance.
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
>>> import torch
|
||||
|
||||
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||
>>> model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
... "Salesforce/blip2-flan-t5-xl", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.bfloat16
|
||||
... ) # doctest: +IGNORE_RESULT
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> prompt = "Question: how many cats are there? Answer:"
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
>>> generated_ids = model.generate(**inputs)
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
|
@ -72,6 +72,7 @@ src/transformers/models/blip/image_processing_blip.py
|
||||
src/transformers/models/blip/modeling_blip.py
|
||||
src/transformers/models/blip/modeling_tf_blip.py
|
||||
src/transformers/models/blip/processing_blip.py
|
||||
src/transformers/models/blip_2/modeling_blip_2.py
|
||||
src/transformers/models/blip_2/processing_blip_2.py
|
||||
src/transformers/models/bloom/configuration_bloom.py
|
||||
src/transformers/models/bloom/tokenization_bloom_fast.py
|
||||
|
Loading…
Reference in New Issue
Block a user