mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
LLaVaNeXT: pad on right if training (#32134)
* pad on right if training * docs * add tests
This commit is contained in:
parent
251a2409c6
commit
3aefb4ec7f
@ -43,6 +43,13 @@ The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tre
|
||||
|
||||
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding".
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that.
|
||||
|
||||
We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows:
|
||||
|
@ -46,6 +46,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
|
||||
|
||||
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding".
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history, passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities. Below is an example of how to do that and the list of formats accepted by each checkpoint.
|
||||
|
||||
We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows:
|
||||
|
@ -518,8 +518,8 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
_left_padding = torch.any(attention_mask[:, 0] == 0)
|
||||
_right_padding = torch.any(attention_mask[:, -1] == 0)
|
||||
|
||||
left_padding = True
|
||||
if batch_size > 1:
|
||||
left_padding = True if not self.training else False
|
||||
if batch_size > 1 and not self.training:
|
||||
if _left_padding and not _right_padding:
|
||||
left_padding = True
|
||||
elif not _left_padding and _right_padding:
|
||||
|
@ -562,8 +562,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
|
||||
_left_padding = torch.any(attention_mask[:, 0] == 0)
|
||||
_right_padding = torch.any(attention_mask[:, -1] == 0)
|
||||
|
||||
left_padding = True
|
||||
if batch_size > 1:
|
||||
left_padding = True if not self.training else False
|
||||
if batch_size > 1 and not self.training:
|
||||
if _left_padding and not _right_padding:
|
||||
left_padding = True
|
||||
elif not _left_padding and _right_padding:
|
||||
|
@ -123,7 +123,7 @@ class LlavaNextVisionText2TextModelTester:
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
self.image_size = 30
|
||||
self.encoder_seq_length = 341
|
||||
self.encoder_seq_length = 342
|
||||
self.image_grid_pinpoints = [[32, 32]]
|
||||
|
||||
def get_config(self):
|
||||
@ -156,9 +156,7 @@ class LlavaNextVisionText2TextModelTester:
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
||||
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||
attention_mask[:, :1] = 0
|
||||
# we are giving 3 images let's make sure we pass in 3 image tokens
|
||||
input_ids[:, 1] = config.image_token_index
|
||||
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
||||
@ -473,3 +471,37 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
self.processor.decode(output_batched[0], skip_special_tokens=True),
|
||||
self.processor.decode(output_single[0], skip_special_tokens=True),
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_padding_side_when_merging_inputs(self):
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
|
||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
|
||||
|
||||
inputs_batched = self.processor(
|
||||
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
|
||||
).to(torch_device)
|
||||
|
||||
# model is in eval mode by default so we should get pad on the left side
|
||||
# we can check the first hidden-states (aka inputs embeds)
|
||||
# the first element was lo-res image and we expect the first 1414 tokens to be all pads
|
||||
output_eval = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item())
|
||||
|
||||
# otherwise padding is on the right side, so it's last 1414 tokens
|
||||
self.processor.padding_side = "right"
|
||||
inputs_batched = self.processor(
|
||||
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
|
||||
).to(torch_device)
|
||||
|
||||
model.train()
|
||||
with torch.no_grad():
|
||||
output_train = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item())
|
||||
|
@ -124,7 +124,7 @@ class LlavaNextVideoVisionText2TextModelTester:
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
self.image_size = 30
|
||||
self.encoder_seq_length = 468
|
||||
self.encoder_seq_length = 469
|
||||
self.image_grid_pinpoints = [[32, 32]]
|
||||
|
||||
def get_config(self):
|
||||
@ -166,9 +166,7 @@ class LlavaNextVideoVisionText2TextModelTester:
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
||||
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||
attention_mask[:, :1] = 0
|
||||
# we are giving 3 images and videos let's make sure we pass in 3 special tokens
|
||||
input_ids[:, 1] = config.image_token_index
|
||||
input_ids[:, 2] = config.video_token_index
|
||||
@ -453,3 +451,39 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
self.processor.decode(output_batched[0], skip_special_tokens=True),
|
||||
self.processor.decode(output_single[0], skip_special_tokens=True),
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_padding_side_when_merging_inputs(self):
|
||||
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
|
||||
)
|
||||
|
||||
inputs_batched = self.processor(
|
||||
[self.prompt_video, self.prompt_image],
|
||||
images=[self.image],
|
||||
videos=[self.video],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
# model is in eval mode by default so we should get pad on the left side
|
||||
# we can check the first hidden-states (aka inputs embeds)
|
||||
# the first element was lo-res image and we expect the first 1482 tokens to be all pads
|
||||
output_eval = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())
|
||||
|
||||
# otherwise padding is on the right side, so it's last 1482 tokens
|
||||
self.processor.padding_side = "right"
|
||||
inputs_batched = self.processor(
|
||||
[self.prompt_video, self.prompt_image],
|
||||
images=[self.image],
|
||||
videos=[self.video],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
model.train()
|
||||
with torch.no_grad():
|
||||
output_train = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())
|
||||
|
Loading…
Reference in New Issue
Block a user