mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Better llava next. (#29850)
* Better llava next. - Batched forward with multiple image of different sizes (number of patches). - Support training, for cases without any image. - Support multi-image in same sequence. e.g: ["<image> <image> the first image is a dog while the second is a cat", "<image> <image> <image> <image> these 4 image are..."] Current limitation: - Haven't done testing - Only support right padding (for training) - left padding (batched generation) is not ready yet. - PR not ready. * fix bugs in batched generation * add tests * fix batch-gen bugs, left-padding positions and incorrect attention mask * remove better modeling llava * fix formatting * fix test * fix test * fix testing * fix test * fix formatting * Update src/transformers/models/llava_next/modeling_llava_next.py add clarity Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_llava_next.py remove assert * fix bug modeling_llava_next.py * update modeling * fix bugs * fix format * fix error * fix new_token_positions * Update modeling_llava_next.py * update formatting * add args * removecomments * add slow tests for batched inference * failing tf/flax tests * this one ic correct * Update src/transformers/models/llava_next/modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix docs * make fixup * more fixup * add test for batch equivalence * Update tests/models/llava_next/test_modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/image_processing_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/image_processing_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/modeling_llava_next.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * pr comments * hardcode padding side for bs=1 * update * [run-slow] llava_next * [run-slow] llava_next * make fix-copies --------- Co-authored-by: NGUYEN, Xuan Phi <x.nguyen@alibaba-inc.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: raushan <raushan@huggingface.co> Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
This commit is contained in:
parent
bdfefbadaf
commit
5ca085b882
@ -15,12 +15,13 @@
|
||||
"""Image processor class for LLaVa-NeXT."""
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
|
||||
from ...image_transforms import (
|
||||
PaddingMode,
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
pad,
|
||||
@ -154,6 +155,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
|
||||
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
@ -173,6 +177,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: Optional[bool] = True,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -251,6 +256,74 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def pad(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
|
||||
mode: PaddingMode = PaddingMode.CONSTANT,
|
||||
constant_values: Union[float, Iterable[float]] = 0.0,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
|
||||
dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected
|
||||
as input.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
The image to pad.
|
||||
padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
|
||||
Padding to apply to the edges of the height, width axes. Can be one of three formats:
|
||||
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
|
||||
- `((before, after),)` yields same before and after pad for height and width.
|
||||
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
|
||||
mode (`PaddingMode`):
|
||||
The padding mode to use. Can be one of:
|
||||
- `"constant"`: pads with a constant value.
|
||||
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
||||
vector along each axis.
|
||||
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
||||
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
||||
constant_values (`float` or `Iterable[float]`, *optional*):
|
||||
The value to use for the padding if `mode` is `"constant"`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The padded image.
|
||||
|
||||
"""
|
||||
|
||||
# call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim
|
||||
if isinstance(padding, int) or len(padding) != 4:
|
||||
return pad(image, padding, mode, constant_values, data_format, input_data_format)
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
if mode == PaddingMode.CONSTANT:
|
||||
image = np.pad(image, padding, mode="constant", constant_values=constant_values)
|
||||
elif mode == PaddingMode.REFLECT:
|
||||
image = np.pad(image, padding, mode="reflect")
|
||||
elif mode == PaddingMode.REPLICATE:
|
||||
image = np.pad(image, padding, mode="edge")
|
||||
elif mode == PaddingMode.SYMMETRIC:
|
||||
image = np.pad(image, padding, mode="symmetric")
|
||||
else:
|
||||
raise ValueError(f"Invalid padding mode: {mode}")
|
||||
image = (
|
||||
to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
||||
)
|
||||
return image
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
@ -378,7 +451,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
|
||||
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
|
||||
|
||||
return padded_image
|
||||
|
||||
@ -446,6 +519,45 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
|
||||
return image_patches
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: List[np.ndarray],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[np.ndarray]`):
|
||||
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
|
||||
Returns:
|
||||
List[`np.ndarray`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
self.pad(
|
||||
image,
|
||||
padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
@ -460,6 +572,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: Optional[bool] = True,
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
@ -496,6 +609,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_pad (`bool`, *optional*, defaults to self.do_pad):
|
||||
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
|
||||
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
@ -516,6 +632,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -603,6 +720,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
pixel_values = np.array(pixel_values)
|
||||
new_images.append(pixel_values)
|
||||
|
||||
data = {"pixel_values": new_images, "image_sizes": image_sizes}
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(new_images)
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
|
||||
)
|
||||
|
@ -12,12 +12,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
"""PyTorch Llava-NeXT model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
@ -61,10 +62,55 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
|
||||
if not isinstance(image_size, (list, tuple)):
|
||||
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
|
||||
raise ValueError(
|
||||
f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
|
||||
)
|
||||
image_size = image_size.tolist()
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
|
||||
"""
|
||||
Calculate the number of patches after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`Union[torch.LongTensor, np.ndarray, Tuple[int, int]):
|
||||
The size of the input image in the format (height, width). ?
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
int: the number of patches
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
|
||||
if not isinstance(image_size, (list, tuple)):
|
||||
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
|
||||
raise ValueError(f"image_size invalid type {type(image_size)} with value {image_size}")
|
||||
image_size = image_size.tolist()
|
||||
|
||||
best_resolution = select_best_resolution(image_size, grid_pinpoints)
|
||||
height, width = best_resolution
|
||||
num_patches = 0
|
||||
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
|
||||
for i in range(0, height, patch_size):
|
||||
for j in range(0, width, patch_size):
|
||||
num_patches += 1
|
||||
# add the base patch
|
||||
num_patches += 1
|
||||
return num_patches
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
@ -310,8 +356,19 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
config.text_config, attn_implementation=config._attn_implementation
|
||||
)
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
|
||||
self.post_init()
|
||||
|
||||
@property
|
||||
def padding_side(self):
|
||||
return self._padding_side
|
||||
|
||||
@padding_side.setter
|
||||
def padding_side(self, padding_side: str):
|
||||
if padding_side not in ["left", "right"]:
|
||||
raise ValueError(f"{padding_side} is not `left` or `right`.")
|
||||
self._padding_side = padding_side
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
@ -348,28 +405,170 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
return model_embeds
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids=None,
|
||||
labels=None,
|
||||
image_token_index=None,
|
||||
ignore_index=-100,
|
||||
):
|
||||
"""
|
||||
Merge input_ids with with image features into final embeddings
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
Args:
|
||||
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
|
||||
All vision vectors of all images in the batch
|
||||
feature_lens (`torch.LongTensor` of shape `(num_images)`):
|
||||
The length of visual embeddings of each image as stacked in `image_features`
|
||||
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
|
||||
Token embeddings before merging with visual embeddings
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Input_ids of tokens, possibly filled with image token
|
||||
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
|
||||
:abels need to be recalculated to support training (if provided)
|
||||
image_token_index (`int`, *optional*)
|
||||
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
|
||||
ignore_index (`int`, *optional*)
|
||||
Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
|
||||
Returns:
|
||||
final_embedding, final_attention_mask, position_ids, final_labels
|
||||
|
||||
Explanation:
|
||||
each image has variable length embeddings, with length specified by feature_lens
|
||||
image_features is concatenation of all visual embed vectors
|
||||
task: fill each <image> with the correct number of visual embeddings
|
||||
Example:
|
||||
X (5 patches), Y (3 patches), Z (8)
|
||||
X, Y are in the same sequence (in-context learning)
|
||||
if right padding
|
||||
input_ids: [
|
||||
a b c d e f X g h i j k Y l m
|
||||
o p q r Z s t u v _ _ _ _ _ _
|
||||
]
|
||||
input_ids should be: [
|
||||
a b c d e f X X X X X g h i j k Y Y Y l m
|
||||
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
|
||||
]
|
||||
labels should be: [
|
||||
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
||||
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
|
||||
]
|
||||
elif left padding
|
||||
input_ids: [
|
||||
a b c d e f X g h i j k Y l m
|
||||
_ _ _ _ _ _ o p q r Z s t u v
|
||||
]
|
||||
input_ids should be: [
|
||||
a b c d e f X X X X X g h i j k Y Y Y l m
|
||||
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
|
||||
]
|
||||
labels should be: [
|
||||
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
||||
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
|
||||
]
|
||||
Edge cases:
|
||||
* If tokens are same but image token sizes are different, then cannot infer left or right padding
|
||||
```python
|
||||
cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
|
||||
prompts = [
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]",
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]",
|
||||
]
|
||||
inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
|
||||
chart_img has 2634 tokens, while cat_img has 2340 tokens
|
||||
```
|
||||
|
||||
input_ids: [
|
||||
a b c d X g h
|
||||
i j Y k l m n
|
||||
]
|
||||
where X is 3 tokens while Y is 5, this mean after merge
|
||||
if left-padding (batched generation)
|
||||
input_ids should be: [
|
||||
_ _ a b c d X X X g h
|
||||
i j Y Y Y Y Y k l m n
|
||||
]
|
||||
elif (right padding) (training)
|
||||
input_ids should be: [
|
||||
a b c d X X X g h _ _
|
||||
i j Y Y Y Y Y k l m n
|
||||
]
|
||||
"""
|
||||
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
|
||||
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
|
||||
|
||||
with torch.no_grad():
|
||||
# ! in llava 1.6, number of patches is variable
|
||||
num_images = feature_lens.size(0)
|
||||
num_image_features, embed_dim = image_features.shape
|
||||
if feature_lens.sum() != num_image_features:
|
||||
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
|
||||
batch_size = input_ids.shape[0]
|
||||
_left_padding = torch.any(attention_mask[:, 0] == 0)
|
||||
_right_padding = torch.any(attention_mask[:, -1] == 0)
|
||||
|
||||
left_padding = True
|
||||
if batch_size > 1:
|
||||
if _left_padding and not _right_padding:
|
||||
left_padding = True
|
||||
elif not _left_padding and _right_padding:
|
||||
left_padding = False
|
||||
elif not _left_padding and not _right_padding:
|
||||
# both side is 1, so cannot tell
|
||||
left_padding = self.padding_side == "left"
|
||||
else:
|
||||
# invalid attention_mask
|
||||
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
|
||||
|
||||
# Whether to turn off right padding
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == image_token_index
|
||||
# special_image_token_mask: [bsz, seqlen]
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# num_special_image_tokens: [bsz]
|
||||
# Reserve for padding of num_images
|
||||
total_num_special_image_tokens = torch.sum(special_image_token_mask)
|
||||
if total_num_special_image_tokens != num_images:
|
||||
raise ValueError(
|
||||
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
|
||||
)
|
||||
# Compute the maximum embed dimension
|
||||
# max_image_feature_lens is max_feature_lens per batch
|
||||
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
|
||||
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
|
||||
embed_sequence_lengths = (
|
||||
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
|
||||
)
|
||||
max_embed_dim = embed_sequence_lengths.max()
|
||||
|
||||
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
# ! instead of special_image_token_mask * (num_image_patches - 1)
|
||||
# special_image_token_mask * (num_feature_len - 1)
|
||||
special_image_token_mask = special_image_token_mask.long()
|
||||
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
|
||||
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
|
||||
if left_padding:
|
||||
# shift right token positions so that they are ending at the same number
|
||||
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
|
||||
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
|
||||
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
@ -378,10 +577,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_labels = None
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
@ -400,32 +598,89 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
with torch.no_grad():
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
|
||||
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
|
||||
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
|
||||
|
||||
if left_padding:
|
||||
# exclude padding on the left
|
||||
val = (max_embed_dim - embed_indices) <= embed_seq_lens
|
||||
else:
|
||||
# exclude padding on the right
|
||||
val = embed_indices < embed_seq_lens
|
||||
image_to_overwrite &= val
|
||||
|
||||
if image_to_overwrite.sum() != num_image_features:
|
||||
raise ValueError(
|
||||
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
|
||||
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. "
|
||||
f"This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
return final_embedding, final_attention_mask, position_ids, final_labels
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
def pack_image_features(self, image_features, image_sizes, image_newline=None):
|
||||
"""
|
||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
Args:
|
||||
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
|
||||
List of image feature tensor, each contains all the visual feature of all patches.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
||||
New line embedding vector.
|
||||
Returns:
|
||||
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
|
||||
feature_lens (`List[int]`)
|
||||
token length of each image in image_features
|
||||
"""
|
||||
new_image_features = []
|
||||
feature_lens = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError("The number of patches is not consistent with the image size.")
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
|
||||
new_image_features.append(image_feature)
|
||||
feature_lens.append(image_feature.size(0))
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
|
||||
return image_features, feature_lens
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
@ -493,14 +748,34 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
# In case image_token_index is not in the embeddings (extra token but embedding don't have it)
|
||||
for_inputs_embeds_ids = input_ids.clone()
|
||||
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
|
||||
inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
|
||||
|
||||
# 2. Merge text and images
|
||||
if pixel_values is not None and input_ids.shape[1] != 1:
|
||||
batch_size, num_patches, num_channels, height, width = pixel_values.shape
|
||||
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
|
||||
image_features = self.vision_tower(reshaped_pixel_values, output_hidden_states=True)
|
||||
if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in image_sizes
|
||||
]
|
||||
# figure out if pixel_values is concatenated or stacked
|
||||
if pixel_values.dim() == 5:
|
||||
# stacking when input is (batch_size, num_patches, num_channels, height, width)
|
||||
_pixel_values_list = [
|
||||
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
|
||||
]
|
||||
pixel_values = torch.cat(_pixel_values_list, dim=0)
|
||||
elif pixel_values.dim() != 4:
|
||||
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
|
||||
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
|
||||
|
||||
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
selected_image_feature = image_features.hidden_states[vision_feature_layer]
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
@ -510,55 +785,31 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [image.shape[0] for image in pixel_values]
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError("The number of patches is not consistent with the image size.")
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None]
|
||||
.expand(*image_feature.shape[:-1], 1)
|
||||
.to(image_feature.dtype),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
if labels is None:
|
||||
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
||||
|
||||
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||
inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# pixel_values is not None but is empty ---> text only cases
|
||||
elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
|
||||
# there are no images
|
||||
pass
|
||||
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
@ -591,6 +842,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
outputs = self.language_model(
|
||||
|
@ -16,7 +16,6 @@
|
||||
Processor class for LLaVa-NeXT.
|
||||
"""
|
||||
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
@ -53,7 +52,8 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
images: ImageInput = None,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = None,
|
||||
max_length=None,
|
||||
max_length: Optional[int] = None,
|
||||
do_pad: Optional[bool] = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -82,6 +82,9 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
do_pad (`bool`, *optional*, defaults to self.do_pad):
|
||||
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
|
||||
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
|
||||
truncation (`bool`, *optional*):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
@ -102,7 +105,7 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, return_tensors=return_tensors)
|
||||
image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
|
||||
else:
|
||||
image_inputs = {}
|
||||
text_inputs = self.tokenizer(
|
||||
|
@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Llava-NeXT model. """
|
||||
"""Testing suite for the PyTorch Llava-NeXT model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
@ -46,6 +46,8 @@ from ...test_modeling_common import (
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
@ -121,7 +123,7 @@ class LlavaNextVisionText2TextModelTester:
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
self.image_size = 30
|
||||
self.encoder_seq_length = 342
|
||||
self.encoder_seq_length = 341
|
||||
self.image_grid_pinpoints = [[32, 32]]
|
||||
|
||||
def get_config(self):
|
||||
@ -153,10 +155,15 @@ class LlavaNextVisionText2TextModelTester:
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
||||
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||
attention_mask[:, :1] = 0
|
||||
# we are giving 3 images let's make sure we pass in 3 image tokens
|
||||
input_ids[:, 1] = config.image_token_index
|
||||
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
||||
# maskout where the image token is
|
||||
labels[:, 1] == self.ignore_index
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"image_sizes": torch.tensor(
|
||||
@ -164,6 +171,7 @@ class LlavaNextVisionText2TextModelTester:
|
||||
),
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
@ -341,10 +349,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
# make sure image_sizes are the same
|
||||
# as otherwise batched generation doesn't work
|
||||
inputs.image_sizes[1] = inputs.image_sizes[0]
|
||||
|
||||
# it should not matter whether two images are the same size or not
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
|
||||
@ -378,3 +383,85 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_batch_different_resolutions(self):
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
|
||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
|
||||
|
||||
inputs = self.processor(
|
||||
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
|
||||
).to(torch_device)
|
||||
pixel_values = inputs["pixel_values"]
|
||||
|
||||
# verify pixel values are padded correctly with 0 when one image has more num_patches than the other
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=model.config.image_grid_pinpoints,
|
||||
patch_size=model.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in inputs["image_sizes"]
|
||||
]
|
||||
for pix_val, num_patch in zip(pixel_values, image_num_patches):
|
||||
self.assertTrue(torch.all(pix_val[num_patch:] == 0)) # pad on the right
|
||||
for i in range(num_patch):
|
||||
self.assertFalse(torch.all(pix_val[i : i + 1] == 0)) # no padding expected in any of patches
|
||||
|
||||
# check loss when labels are passed
|
||||
inputs["labels"] = inputs["input_ids"].clone()
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.0308, -0.0313, -0.0314], [-0.3064, -0.3013, -0.2986], [-0.1226, -0.1246, -0.1210]],
|
||||
dtype=torch.float32,
|
||||
device=torch_device,
|
||||
)
|
||||
assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3)
|
||||
assert torch.allclose(output.loss, torch.tensor(6.8619, device=torch_device))
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few deer grazing. The deer are partially obscured by the fog, and the trees in the background' # fmt: skip
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_batch_matches_single(self):
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
|
||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
|
||||
|
||||
inputs_batched = self.processor(
|
||||
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
|
||||
).to(torch_device)
|
||||
|
||||
inputs_single = self.processor(self.prompt, images=lowres_img, return_tensors="pt", padding=True).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# verify generation
|
||||
output_batched = model.generate(**inputs_batched, max_new_tokens=50)
|
||||
output_single = model.generate(**inputs_single, max_new_tokens=50)
|
||||
self.assertEqual(
|
||||
self.processor.decode(output_batched[0], skip_special_tokens=True),
|
||||
self.processor.decode(output_single[0], skip_special_tokens=True),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user