LLaVaNeXT: pad on right if training (#32134)

* pad on right if training

* docs

* add tests
This commit is contained in:
Raushan Turganbay 2024-07-23 10:23:55 +05:00 committed by GitHub
parent 251a2409c6
commit 3aefb4ec7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 90 additions and 10 deletions

View File

@ -43,6 +43,13 @@ The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tre
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
<Tip warning={true}>
- Llava-Next uses a different number of patches for each image and thus has to pad the inputs inside the modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if the model is in `eval()` mode, otherwise "right-padding".
</Tip>
- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that.
We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows:

View File

@ -46,6 +46,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
<Tip warning={true}>
- Llava-Next uses a different number of patches for each image and thus has to pad the inputs inside the modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if the model is in `eval()` mode, otherwise "right-padding".
</Tip>
- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history, passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities. Below is an example of how to do that and the list of formats accepted by each checkpoint.
We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows:

View File

@ -518,8 +518,8 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = True
if batch_size > 1:
left_padding = True if not self.training else False
if batch_size > 1 and not self.training:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:

View File

@ -562,8 +562,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = True
if batch_size > 1:
left_padding = True if not self.training else False
if batch_size > 1 and not self.training:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:

View File

@ -123,7 +123,7 @@ class LlavaNextVisionText2TextModelTester:
self.batch_size = 3
self.num_channels = 3
self.image_size = 30
self.encoder_seq_length = 341
self.encoder_seq_length = 342
self.image_grid_pinpoints = [[32, 32]]
def get_config(self):
@ -156,9 +156,7 @@ class LlavaNextVisionText2TextModelTester:
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 images let's make sure we pass in 3 image tokens
input_ids[:, 1] = config.image_token_index
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
@ -473,3 +471,37 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
self.processor.decode(output_batched[0], skip_special_tokens=True),
self.processor.decode(output_single[0], skip_special_tokens=True),
)
@slow
@require_bitsandbytes
def test_padding_side_when_merging_inputs(self):
    """Check which side the model pads on when merging image features into the inputs.

    The batch contains the same prompt twice, paired with a low-resolution image and a
    regular image. The test expects the first sequence to need 1414 extra positions
    after merging, and checks that those positions are all-zero in the first
    hidden-states (the input embeddings): at the start of the sequence in `eval()`
    mode and at the end of the sequence in `train()` mode.
    """
    model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        load_in_4bit=True,
    )

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
    cats_image = Image.open(requests.get(url, stream=True).raw)
    lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

    inputs_batched = self.processor(
        [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
    ).to(torch_device)

    # model is in eval mode by default so we should get pad on the left side
    # we can check the first hidden-states (aka inputs embeds)
    # the first element was lo-res image and we expect the first 1414 tokens to be all pads
    # Only the embeddings are inspected, so skip building the autograd graph
    # (same as for the train-mode forward pass below).
    with torch.no_grad():
        output_eval = model(**inputs_batched, output_hidden_states=True)
    self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item())

    # otherwise padding is on the right side, so it's last 1414 tokens
    # The padding side is configured on the tokenizer (`processor.tokenizer.padding_side`,
    # as described in the model docs), not on the processor object itself.
    self.processor.tokenizer.padding_side = "right"
    inputs_batched = self.processor(
        [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
    ).to(torch_device)

    model.train()
    with torch.no_grad():
        output_train = model(**inputs_batched, output_hidden_states=True)
    self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item())

View File

@ -124,7 +124,7 @@ class LlavaNextVideoVisionText2TextModelTester:
self.batch_size = 3
self.num_channels = 3
self.image_size = 30
self.encoder_seq_length = 468
self.encoder_seq_length = 469
self.image_grid_pinpoints = [[32, 32]]
def get_config(self):
@ -166,9 +166,7 @@ class LlavaNextVideoVisionText2TextModelTester:
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 images and videos let's make sure we pass in 3 special tokens
input_ids[:, 1] = config.image_token_index
input_ids[:, 2] = config.video_token_index
@ -453,3 +451,39 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
self.processor.decode(output_batched[0], skip_special_tokens=True),
self.processor.decode(output_single[0], skip_special_tokens=True),
)
@slow
@require_bitsandbytes
def test_padding_side_when_merging_inputs(self):
    """Check which side the model pads on when merging image/video features into the inputs.

    The batch contains a video prompt followed by an image prompt. The test expects the
    first sequence (the video prompt) to need 1482 extra positions after merging, and
    checks that those positions are all-zero in the first hidden-states (the input
    embeddings): at the start of the sequence in `eval()` mode and at the end of the
    sequence in `train()` mode.
    """
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
    )

    inputs_batched = self.processor(
        [self.prompt_video, self.prompt_image],
        images=[self.image],
        videos=[self.video],
        return_tensors="pt",
        padding=True,
    ).to(torch_device)

    # model is in eval mode by default so we should get pad on the left side
    # we can check the first hidden-states (aka inputs embeds)
    # the first element is the video prompt and we expect its first 1482 tokens to be all pads
    # Only the embeddings are inspected, so skip building the autograd graph
    # (same as for the train-mode forward pass below).
    with torch.no_grad():
        output_eval = model(**inputs_batched, output_hidden_states=True)
    self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())

    # otherwise padding is on the right side, so it's last 1482 tokens
    # The padding side is configured on the tokenizer (`processor.tokenizer.padding_side`,
    # as described in the model docs), not on the processor object itself.
    self.processor.tokenizer.padding_side = "right"
    inputs_batched = self.processor(
        [self.prompt_video, self.prompt_image],
        images=[self.image],
        videos=[self.video],
        return_tensors="pt",
        padding=True,
    ).to(torch_device)

    model.train()
    with torch.no_grad():
        output_train = model(**inputs_batched, output_hidden_states=True)
    self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())