[Fuyu] Add tests (#27001)

* Add tests

* Add integration test

* More improvements

* Fix tests

* Fix style

* Skip gradient checkpointing tests

* Update script

* Remove scripts

* Remove Fuyu from auto mapping

* Fix integration test

* More improvements

* Remove file

* Add Fuyu to slow documentation tests

* Address comments

* Clarify comment
This commit is contained in:
NielsRogge 2023-11-15 09:33:04 +01:00 committed by GitHub
parent 186c077513
commit cc0dc24bc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 139 additions and 73 deletions

View File

@ -19,10 +19,10 @@ import torch
import torch.utils.checkpoint
from torch import nn
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...models.auto.modeling_auto import AutoModelForCausalLM
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_fuyu import FuyuConfig
@ -101,6 +101,11 @@ FUYU_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size_ x patch_size x num_channels)`, *optional*):
Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
hidden size of the model.
image_patches_indices (`torch.LongTensor` of shape `(batch_size, num_total_patches + number_of_newline_tokens + number_of_text_tokens, patch_size_ x patch_size x num_channels )`, *optional*):
Indices indicating at which position the image_patches have to be inserted in input_embeds.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
@ -136,17 +141,10 @@ FUYU_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare Fuyu Model outputting raw hidden-states without any specific head on top.",
"Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
FUYU_START_DOCSTRING,
)
class FuyuForCausalLM(FuyuPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FuyuDecoderLayer`]
Args:
config: FuyuConfig
"""
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
@ -178,12 +176,14 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
embeddings.
Args:
word_embeddings: Tensor of word embeddings. Shape: [b, s, h]
continuous_embeddings:
Tensor of continuous embeddings. The length of the list is the batch size. Each entry is
shape [num_image_embeddings, hidden], and num_image_embeddings needs to match the number of non-negative
indices in image_patch_input_indices for that batch element.
image_patch_input_indices: Tensor of indices of the image patches in the input_ids tensor. Shape: [b, s]
word_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Tensor of word embeddings.
continuous_embeddings (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
Tensor of continuous embeddings. The length of the list is the batch size. Each entry is shape
[num_image_embeddings, hidden], and num_image_embeddings needs to match the number of non-negative
indices in image_patch_input_indices for that batch element.
image_patch_input_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Tensor of indices of the image patches in the input_ids tensor.
"""
if not (word_embeddings.shape[0] == len(continuous_embeddings)):
raise ValueError(
@ -208,6 +208,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
return output_embeddings
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
@ -218,10 +219,42 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Examples:
```python
>>> from transformers import FuyuProcessor, FuyuForCausalLM
>>> from PIL import Image
>>> import requests
>>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
>>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a coco-style caption.\n"
>>> inputs = processor(text=text_prompt, images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=7)
>>> generation_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> print(generation_text)
'A bus parked on the side of a road.'
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -230,15 +263,14 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
raise ValueError("You have to specify either input_is or inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
@ -273,10 +305,12 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
labels=labels,
use_cache=use_cache,
return_dict=return_dict,
)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return outputs
def prepare_inputs_for_generation(

View File

@ -1,12 +1,29 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Fuyu model. """
import io
import unittest
import requests
from transformers import FuyuConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.utils import cached_property
from ...test_modeling_common import ids_tensor, random_attention_mask
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_vision_available():
@ -23,19 +40,17 @@ if is_torch_available():
from transformers import FuyuForCausalLM
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Fuyu
class FuyuModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
image_size=300,
patch_size=30,
image_size=30,
patch_size=15,
num_channels=3,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
@ -62,7 +77,6 @@ class FuyuModelTester:
self.num_channels = num_channels
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
@ -88,21 +102,15 @@ class FuyuModelTester:
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
return config, input_ids, input_mask, sequence_labels, token_labels
def get_config(self):
return FuyuConfig(
@ -122,7 +130,12 @@ class FuyuModelTester:
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
self,
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
):
model = FuyuForCausalLM(config=config)
model.to(torch_device)
@ -135,11 +148,9 @@ class FuyuModelTester:
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
@ -165,11 +176,9 @@ class FuyuModelTester:
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
@ -183,11 +192,9 @@ class FuyuModelTester:
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
@ -246,49 +253,73 @@ class FuyuModelTester:
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
@require_torch_accelerator
@slow
class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
"""
Currently, all these tests depend on a value of max_tokens_to_generate of 10.
"""
class FuyuModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = {"image-to-text": FuyuForCausalLM} if is_torch_available() else {}
all_model_classes = ("FuyuForCausalLM") if is_torch_available() else ()
test_head_masking = False
test_pruning = False
test_cpu_offload = False
test_disk_offload = False
test_model_parallel = False
def setUp(self):
self.pretrained_model_name = "adept/fuyu-8b"
self.processor = FuyuProcessor.from_pretrained(self.pretrained_model_name)
self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name)
self.bus_image_url = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
)
self.bus_image_pil = Image.open(io.BytesIO(requests.get(self.bus_image_url).content))
self.model_tester = FuyuModelTester(self)
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@slow
@require_torch_gpu
class FuyuModelIntegrationTest(unittest.TestCase):
@cached_property
def default_processor(self):
return FuyuProcessor.from_pretrained("adept/fuyu-8b")
@cached_property
def default_model(self):
return FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
def test_greedy_generation(self):
processor = self.default_processor
model = self.default_model
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(io.BytesIO(requests.get(url).content))
@slow
def test_model_8b_chat_greedy_generation_bus_captioning(self):
EXPECTED_TEXT_COMPLETION = """A blue bus parked on the side of a road.|ENDOFTEXT|"""
text_prompt_coco_captioning = "Generate a coco-style caption.\n"
model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil)
generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10)
text = self.processor.tokenizer.batch_decode(generated_tokens)
end_sequence = text[0].split("\x04")[1]
clean_sequence = (
end_sequence[: end_sequence.find("|ENDOFTEXT|") + len("|ENDOFTEXT|")]
if "|ENDOFTEXT|" in end_sequence
else end_sequence
)
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence[1:])
inputs = processor(text=text_prompt_coco_captioning, images=image, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=10)
# take the last 8 tokens (in order to skip special \n\x04 characters) and decode them
generated_text = processor.batch_decode(generated_ids[:, -8:], skip_special_tokens=True)[0]
self.assertEqual(generated_text, "A blue bus parked on the side of a road.")
"""

View File

@ -6,4 +6,5 @@ docs/source/en/task_summary.md
docs/source/en/tasks/prompting.md
src/transformers/models/blip_2/modeling_blip_2.py
src/transformers/models/ctrl/modeling_ctrl.py
src/transformers/models/kosmos2/modeling_kosmos2.py
src/transformers/models/fuyu/modeling_fuyu.py
src/transformers/models/kosmos2/modeling_kosmos2.py