Better llava next. (#29850)

* Better llava next.
- Batched forward with multiple images of different sizes (different numbers of patches); see the usage sketch below.
- Support training, including cases without any image.
- Support multiple images in the same sequence, e.g.: ["<image> <image> the first image is a dog while the second is a cat", "<image> <image> <image> <image> these 4 images are..."]

Current limitations:
- Testing not done yet.
- Only right padding is supported (for training).
- Left padding (batched generation) is not ready yet.
- PR not ready.
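
A hedged usage sketch of the batched, mixed-resolution behaviour described above. The checkpoint, image URLs and prompts are the ones used in this PR's slow tests; the 4-bit loading mirrors those tests and is only illustrative:

```python
import requests
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)

# Two images with different resolutions, hence different numbers of patches per image.
cats_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
chart_img = Image.open(
    requests.get(
        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true",
        stream=True,
    ).raw
)

prompts = [
    "[INST] <image>\nWhat is shown in this image? [/INST]",
    "[INST] <image>\nWhat is shown in this image? [/INST]",
]

# The image processor now zero-pads the patch dimension so the two images can be batched.
inputs = processor(prompts, images=[chart_img, cats_img], padding=True, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output, skip_special_tokens=True))
```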

* fix bugs in batched generation

* add tests

* fix batch-gen bugs, left-padding positions and incorrect attention mask

* remove better modeling llava

* fix formatting

* fix test

* fix test

* fix testing

* fix test

* fix formatting

* Update src/transformers/models/llava_next/modeling_llava_next.py

add clarity

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update modeling_llava_next.py

remove assert

* fix bug modeling_llava_next.py

* update modeling

* fix bugs

* fix format

* fix error

* fix new_token_positions

* Update modeling_llava_next.py

* update formatting

* add args

* remove comments

* add slow tests for batched inference

* failing tf/flax tests

* this one is correct

* Update src/transformers/models/llava_next/modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix docs

* make fixup

* more fixup

* add test for batch equivalence

* Update tests/models/llava_next/test_modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/image_processing_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/image_processing_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pr comments

* hardcode padding side for bs=1

* update

* [run-slow] llava_next

* [run-slow] llava_next

* make fix-copies

---------

Co-authored-by: NGUYEN, Xuan Phi <x.nguyen@alibaba-inc.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Xuan-Phi Nguyen, 2024-05-15 07:02:56 -07:00, committed by GitHub
parent bdfefbadaf
commit 5ca085b882
4 changed files with 570 additions and 108 deletions

src/transformers/models/llava_next/image_processing_llava_next.py

@@ -15,12 +15,13 @@
"""Image processor class for LLaVa-NeXT."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
get_resize_output_image_size,
pad,
@@ -154,6 +155,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
@@ -173,6 +177,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = True,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
def pad(
self,
image: np.ndarray,
padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
mode: PaddingMode = PaddingMode.CONSTANT,
constant_values: Union[float, Iterable[float]] = 0.0,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
dimension or in the (`num_patches`) dimension. In the second case an iterable of tuples is expected
as input.
Args:
image (`np.ndarray`):
The image to pad.
padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
Padding to apply to the edges of the height, width axes. Can be one of three formats:
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
- `((before, after),)` yields same before and after pad for height and width.
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
mode (`PaddingMode`):
The padding mode to use. Can be one of:
- `"constant"`: pads with a constant value.
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
vector along each axis.
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
`np.ndarray`: The padded image.
"""
# call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
if isinstance(padding, int) or len(padding) != 4:
return pad(image, padding, mode, constant_values, data_format, input_data_format)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if mode == PaddingMode.CONSTANT:
image = np.pad(image, padding, mode="constant", constant_values=constant_values)
elif mode == PaddingMode.REFLECT:
image = np.pad(image, padding, mode="reflect")
elif mode == PaddingMode.REPLICATE:
image = np.pad(image, padding, mode="edge")
elif mode == PaddingMode.SYMMETRIC:
image = np.pad(image, padding, mode="symmetric")
else:
raise ValueError(f"Invalid padding mode: {mode}")
image = (
to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
)
return image
def _preprocess(
self,
images: ImageInput,
@@ -378,7 +451,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
return padded_image
@@ -446,6 +519,45 @@ class LlavaNextImageProcessor(BaseImageProcessor):
return image_patches
def _pad_for_batching(
self,
pixel_values: List[np.ndarray],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Pads images on the `num_patches` dimension with zeros to form a batch with the same number of patches.
Args:
pixel_values (`List[np.ndarray]`):
A list of arrays of pixel values, one per image, each of shape (`num_patches`, `image_in_3D`)
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
List[`np.ndarray`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
self.pad(
image,
padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
data_format=data_format,
input_data_format=input_data_format,
)
for image in pixel_values
]
return pixel_values
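
To make the patch-dimension padding concrete, here is a self-contained NumPy sketch of the scheme `_pad_for_batching` implements (shapes are illustrative, not tied to a real checkpoint):

```python
import numpy as np

# Two images with different numbers of patches.
pixel_values = [np.ones((5, 3, 336, 336)), np.ones((2, 3, 336, 336))]

max_patch = max(x.shape[0] for x in pixel_values)
padded = [
    # zero-pad only the patch axis, up to the largest patch count in the batch
    np.pad(img, ((0, max_patch - img.shape[0]), (0, 0), (0, 0), (0, 0)))
    for img in pixel_values
]
batch = np.stack(padded)  # now stackable into (batch_size, max_patch, 3, 336, 336)
print(batch.shape)        # (2, 5, 3, 336, 336)
```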
def preprocess(
self,
images: ImageInput,
@@ -460,6 +572,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = True,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
@@ -496,6 +609,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
@@ -516,6 +632,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
@@ -603,6 +720,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
pixel_values = np.array(pixel_values)
new_images.append(pixel_values)
if do_pad:
processed_images = self._pad_for_batching(new_images)
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
)

src/transformers/models/llava_next/modeling_llava_next.py

@@ -12,12 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava-NeXT model."""
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
@@ -61,10 +62,55 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT if image_size is a tensor, it must be converted to a tuple, otherwise the calculation will be wrong
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise ValueError(
f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
)
image_size = image_size.tolist()
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
"""
Calculate the number of patches after the preprocessing for images of any resolution.
Args:
image_size (`Union[torch.LongTensor, np.ndarray, Tuple[int, int]]`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
int: the number of patches
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT if image_size is a tensor, it must be converted to a tuple, otherwise the calculation will be wrong
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise ValueError(f"image_size invalid type {type(image_size)} with value {image_size}")
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
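
For intuition, a small worked example of the patch count this helper produces. The grid pinpoints and patch size below mirror common LLaVA-NeXT configs but are only illustrative here:

```python
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

# For a (899, 1024) image, select_best_resolution picks (672, 672) from these pinpoints,
# which tiles into (672 // 336) * (672 // 336) = 4 patches, plus 1 base patch = 5.
num_patches = image_size_to_num_patches(
    image_size=(899, 1024),
    grid_pinpoints=[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
    patch_size=336,
)
assert num_patches == 5
```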
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
@@ -310,8 +356,19 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
config.text_config, attn_implementation=config._attn_implementation
)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
self.post_init()
@property
def padding_side(self):
return self._padding_side
@padding_side.setter
def padding_side(self, padding_side: str):
if padding_side not in ["left", "right"]:
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
@@ -348,28 +405,170 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_image_features(
self,
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids=None,
labels=None,
image_token_index=None,
ignore_index=-100,
):
"""
Merge input_ids with image features into final embeddings
Args:
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
All vision vectors of all images in the batch
feature_lens (`torch.LongTensor` of shape `(num_images)`):
The length of visual embeddings of each image as stacked in `image_features`
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
Token embeddings before merging with visual embeddings
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Input ids of the tokens, possibly filled with image tokens
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels need to be recalculated to support training (if provided)
image_token_index (`int`, *optional*):
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
ignore_index (`int`, *optional*):
Value that is used to pad `labels` and will be ignored when calculating the loss. Default: -100.
Returns:
final_embedding, final_attention_mask, position_ids, final_labels
Explanation:
each image has variable length embeddings, with length specified by feature_lens
image_features is concatenation of all visual embed vectors
task: fill each <image> with the correct number of visual embeddings
Example:
X (5 patches), Y (3 patches), Z (8)
X, Y are in the same sequence (in-context learning)
if right padding
input_ids: [
a b c d e f X g h i j k Y l m
o p q r Z s t u v _ _ _ _ _ _
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
]
elif left padding
input_ids: [
a b c d e f X g h i j k Y l m
_ _ _ _ _ _ o p q r Z s t u v
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
]
Edge cases:
* If the tokens are the same but the image token sizes are different, then we cannot infer left or right padding
```python
cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
prompts = [
"[INST] <image>\nWhat is shown in this image? [/INST]",
"[INST] <image>\nWhat is shown in this image? [/INST]",
]
inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
chart_img has 2634 tokens, while cat_img has 2340 tokens
```
input_ids: [
a b c d X g h
i j Y k l m n
]
where X is 3 tokens while Y is 5, this means after merging
if left-padding (batched generation)
input_ids should be: [
_ _ a b c d X X X g h
i j Y Y Y Y Y k l m n
]
elif (right padding) (training)
input_ids should be: [
a b c d X X X g h _ _
i j Y Y Y Y Y k l m n
]
"""
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
with torch.no_grad():
# ! in llava 1.6, number of patches is variable
num_images = feature_lens.size(0)
num_image_features, embed_dim = image_features.shape
if feature_lens.sum() != num_image_features:
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
batch_size = input_ids.shape[0]
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = True
if batch_size > 1:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:
left_padding = False
elif not _left_padding and not _right_padding:
# both sides are fully attended, so we cannot tell
left_padding = self.padding_side == "left"
else:
# invalid attention_mask
raise ValueError(f"both sides of attention_mask have zeros, invalid. {attention_mask}")
# Whether to turn off right padding
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == image_token_index
# special_image_token_mask: [bsz, seqlen]
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# num_special_image_tokens: [bsz]
# Reserve for padding of num_images
total_num_special_image_tokens = torch.sum(special_image_token_mask)
if total_num_special_image_tokens != num_images:
raise ValueError(
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
)
# Compute the maximum embed dimension
# max_image_feature_lens is max_feature_lens per batch
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
embed_sequence_lengths = (
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
)
max_embed_dim = embed_sequence_lengths.max()
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
# ! instead of special_image_token_mask * (num_image_patches - 1)
# special_image_token_mask * (num_feature_len - 1)
special_image_token_mask = special_image_token_mask.long()
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
if left_padding:
# shift right token positions so that they are ending at the same number
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
@@ -378,10 +577,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
final_labels = None
if labels is not None:
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
@@ -400,32 +598,89 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
with torch.no_grad():
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
if left_padding:
# exclude padding on the left
val = (max_embed_dim - embed_indices) <= embed_seq_lens
else:
# exclude padding on the right
val = embed_indices < embed_seq_lens
image_to_overwrite &= val
if image_to_overwrite.sum() != num_image_features:
raise ValueError(
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model is wrong. "
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of images given to the model is {num_images}. "
f"This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
return final_embedding, final_attention_mask, position_ids, final_labels
def pack_image_features(self, image_features, image_sizes, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensors, each containing all the visual features of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
if height * width != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
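
For intuition, a conceptual sketch of the contract `pack_image_features` provides to the merge step: one flat `(all_feat_len, embed_dim)` tensor plus per-image lengths. It deliberately ignores the anyres reshaping, unpadding and `image_newline` handling that the real method performs; all numbers are illustrative:

```python
import torch

embed_dim = 8  # illustrative
# Pretend image 0 contributes 5 visual vectors and image 1 contributes 3 (after unpadding etc.).
per_image = [torch.randn(5, embed_dim), torch.randn(3, embed_dim)]

image_features = torch.cat(per_image, dim=0)                 # (8, embed_dim)
feature_lens = torch.tensor([f.size(0) for f in per_image])  # tensor([5, 3])

# _merge_input_ids_with_image_features later uses feature_lens to know how many
# embedding slots each <image> placeholder token should expand into.
recovered = torch.split(image_features, feature_lens.tolist(), dim=0)
assert [r.shape for r in recovered] == [f.shape for f in per_image]
```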
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -493,14 +748,34 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
if inputs_embeds is None:
# 1. Extract the input embeddings
# In case image_token_index is not in the embeddings (extra token but embedding don't have it)
for_inputs_embeds_ids = input_ids.clone()
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
# figure out if pixel_values is concatenated or stacked
if pixel_values.dim() == 5:
# stacking when input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
@@ -510,55 +785,31 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
image_newline=self.image_newline,
)
inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
)
# pixel_values is not None but is empty ---> text only cases
elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
# there are no images
pass
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
@@ -591,6 +842,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
outputs = self.language_model(
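
A minimal sketch of how the new `padding_side` setter introduced above is meant to be used (the checkpoint name is the one from the slow tests; everything else is illustrative):

```python
from transformers import LlavaNextForConditionalGeneration

model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model.padding_side = "right"  # right padding for training
model.padding_side = "left"   # left padding for batched generation (the default)
```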

src/transformers/models/llava_next/processing_llava_next.py

@@ -16,7 +16,6 @@
Processor class for LLaVa-NeXT.
"""
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
@@ -53,7 +52,8 @@ class LlavaNextProcessor(ProcessorMixin):
images: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
do_pad: Optional[bool] = True,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
"""
@@ -82,6 +82,9 @@ class LlavaNextProcessor(ProcessorMixin):
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
@@ -102,7 +105,7 @@ class LlavaNextProcessor(ProcessorMixin):
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is not None:
image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
else:
image_inputs = {}
text_inputs = self.tokenizer(

tests/models/llava_next/test_modeling_llava_next.py

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Llava-NeXT model."""
import gc
import unittest
@@ -46,6 +46,8 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
else:
is_torch_greater_or_equal_than_2_0 = False
@@ -121,7 +123,7 @@ class LlavaNextVisionText2TextModelTester:
self.batch_size = 3
self.num_channels = 3
self.image_size = 30
self.encoder_seq_length = 341
self.image_grid_pinpoints = [[32, 32]]
def get_config(self):
@@ -153,10 +155,15 @@ class LlavaNextVisionText2TextModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 images let's make sure we pass in 3 image tokens
input_ids[:, 1] = config.image_token_index
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
# mask out where the image token is
labels[:, 1] = self.ignore_index
inputs_dict = {
"pixel_values": pixel_values,
"image_sizes": torch.tensor(
@@ -164,6 +171,7 @@ class LlavaNextVisionText2TextModelTester:
),
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
return config, inputs_dict
@@ -341,10 +349,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
padding=True,
).to(torch_device)
# it should not matter whether two images are the same size or not
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush']  # fmt: skip
@@ -378,3 +383,85 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_batch_different_resolutions(self):
model = LlavaNextForConditionalGeneration.from_pretrained(
"llava-hf/llava-v1.6-mistral-7b-hf",
load_in_4bit=True,
)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
cats_image = Image.open(requests.get(url, stream=True).raw)
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
inputs = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
).to(torch_device)
pixel_values = inputs["pixel_values"]
# verify pixel values are padded correctly with 0 when one image has more num_patches than the other
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=model.config.image_grid_pinpoints,
patch_size=model.config.vision_config.image_size,
)
for imsize in inputs["image_sizes"]
]
for pix_val, num_patch in zip(pixel_values, image_num_patches):
self.assertTrue(torch.all(pix_val[num_patch:] == 0)) # pad on the right
for i in range(num_patch):
self.assertFalse(torch.all(pix_val[i : i + 1] == 0)) # no padding expected in any of patches
# check loss when labels are passed
inputs["labels"] = inputs["input_ids"].clone()
with torch.no_grad():
output = model(**inputs)
expected_slice = torch.tensor(
[[-0.0308, -0.0313, -0.0314], [-0.3064, -0.3013, -0.2986], [-0.1226, -0.1246, -0.1210]],
dtype=torch.float32,
device=torch_device,
)
assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3)
assert torch.allclose(output.loss, torch.tensor(6.8619, device=torch_device))
# verify generation
output = model.generate(**inputs, max_new_tokens=50)
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few deer grazing. The deer are partially obscured by the fog, and the trees in the background' # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_batch_matches_single(self):
model = LlavaNextForConditionalGeneration.from_pretrained(
"llava-hf/llava-v1.6-mistral-7b-hf",
load_in_4bit=True,
)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
cats_image = Image.open(requests.get(url, stream=True).raw)
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
).to(torch_device)
inputs_single = self.processor(self.prompt, images=lowres_img, return_tensors="pt", padding=True).to(
torch_device
)
# verify generation
output_batched = model.generate(**inputs_batched, max_new_tokens=50)
output_single = model.generate(**inputs_single, max_new_tokens=50)
self.assertEqual(
self.processor.decode(output_batched[0], skip_special_tokens=True),
self.processor.decode(output_single[0], skip_special_tokens=True),
)