mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Fuyu] Add tests (#27001)
* Add tests * Add integration test * More improvements * Fix tests * Fix style * Skip gradient checkpointing tests * Update script * Remove scripts * Remove Fuyu from auto mapping * Fix integration test * More improvements * Remove file * Add Fuyu to slow documentation tests * Address comments * Clarify comment
This commit is contained in:
parent
186c077513
commit
cc0dc24bc9
@ -19,10 +19,10 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_outputs import CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...models.auto.modeling_auto import AutoModelForCausalLM
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from .configuration_fuyu import FuyuConfig
|
||||
|
||||
|
||||
@ -101,6 +101,11 @@ FUYU_INPUTS_DOCSTRING = r"""
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size_ x patch_size x num_channels)`, *optional*):
|
||||
Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
|
||||
hidden size of the model.
|
||||
image_patches_indices (`torch.LongTensor` of shape `(batch_size, num_total_patches + number_of_newline_tokens + number_of_text_tokens, patch_size_ x patch_size x num_channels )`, *optional*):
|
||||
Indices indicating at which position the image_patches have to be inserted in input_embeds.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
@ -136,17 +141,10 @@ FUYU_INPUTS_DOCSTRING = r"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Fuyu Model outputting raw hidden-states without any specific head on top.",
|
||||
"Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
|
||||
FUYU_START_DOCSTRING,
|
||||
)
|
||||
class FuyuForCausalLM(FuyuPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FuyuDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: FuyuConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: FuyuConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -178,12 +176,14 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
|
||||
embeddings.
|
||||
|
||||
Args:
|
||||
word_embeddings: Tensor of word embeddings. Shape: [b, s, h]
|
||||
continuous_embeddings:
|
||||
Tensor of continuous embeddings. The length of the list is the batch size. Each entry is
|
||||
shape [num_image_embeddings, hidden], and num_image_embeddings needs to match the number of non-negative
|
||||
indices in image_patch_input_indices for that batch element.
|
||||
image_patch_input_indices: Tensor of indices of the image patches in the input_ids tensor. Shape: [b, s]
|
||||
word_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Tensor of word embeddings.
|
||||
continuous_embeddings (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
|
||||
Tensor of continuous embeddings. The length of the list is the batch size. Each entry is shape
|
||||
[num_image_embeddings, hidden], and num_image_embeddings needs to match the number of non-negative
|
||||
indices in image_patch_input_indices for that batch element.
|
||||
image_patch_input_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Tensor of indices of the image patches in the input_ids tensor.
|
||||
"""
|
||||
if not (word_embeddings.shape[0] == len(continuous_embeddings)):
|
||||
raise ValueError(
|
||||
@ -208,6 +208,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
|
||||
return output_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@ -218,10 +219,42 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import FuyuProcessor, FuyuForCausalLM
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
|
||||
>>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> prompt = "Generate a coco-style caption.\n"
|
||||
|
||||
>>> inputs = processor(text=text_prompt, images=image, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=7)
|
||||
>>> generation_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
>>> print(generation_text)
|
||||
'A bus parked on the side of a road.'
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -230,15 +263,14 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
raise ValueError("You have to specify either input_is or inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
@ -273,10 +305,12 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
|
@ -1,12 +1,29 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Fuyu model. """
|
||||
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import FuyuConfig, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...test_modeling_common import ids_tensor, random_attention_mask
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@ -23,19 +40,17 @@ if is_torch_available():
|
||||
from transformers import FuyuForCausalLM
|
||||
|
||||
|
||||
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Fuyu
|
||||
class FuyuModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
image_size=300,
|
||||
patch_size=30,
|
||||
image_size=30,
|
||||
patch_size=15,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=False,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
@ -62,7 +77,6 @@ class FuyuModelTester:
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
@ -88,21 +102,15 @@ class FuyuModelTester:
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
return config, input_ids, input_mask, sequence_labels, token_labels
|
||||
|
||||
def get_config(self):
|
||||
return FuyuConfig(
|
||||
@ -122,7 +130,12 @@ class FuyuModelTester:
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
):
|
||||
model = FuyuForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
@ -135,11 +148,9 @@ class FuyuModelTester:
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
@ -165,11 +176,9 @@ class FuyuModelTester:
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
@ -183,11 +192,9 @@ class FuyuModelTester:
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
@ -246,49 +253,73 @@ class FuyuModelTester:
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
|
||||
"""
|
||||
Currently, all these tests depend on a value of max_tokens_to_generate of 10.
|
||||
"""
|
||||
class FuyuModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-to-text": FuyuForCausalLM} if is_torch_available() else {}
|
||||
|
||||
all_model_classes = ("FuyuForCausalLM") if is_torch_available() else ()
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_cpu_offload = False
|
||||
test_disk_offload = False
|
||||
test_model_parallel = False
|
||||
|
||||
def setUp(self):
|
||||
self.pretrained_model_name = "adept/fuyu-8b"
|
||||
self.processor = FuyuProcessor.from_pretrained(self.pretrained_model_name)
|
||||
self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name)
|
||||
self.bus_image_url = (
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
|
||||
)
|
||||
self.bus_image_pil = Image.open(io.BytesIO(requests.get(self.bus_image_url).content))
|
||||
self.model_tester = FuyuModelTester(self)
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class FuyuModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return FuyuProcessor.from_pretrained("adept/fuyu-8b")
|
||||
|
||||
@cached_property
|
||||
def default_model(self):
|
||||
return FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
|
||||
|
||||
def test_greedy_generation(self):
|
||||
processor = self.default_processor
|
||||
model = self.default_model
|
||||
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
|
||||
image = Image.open(io.BytesIO(requests.get(url).content))
|
||||
|
||||
@slow
|
||||
def test_model_8b_chat_greedy_generation_bus_captioning(self):
|
||||
EXPECTED_TEXT_COMPLETION = """A blue bus parked on the side of a road.|ENDOFTEXT|"""
|
||||
text_prompt_coco_captioning = "Generate a coco-style caption.\n"
|
||||
model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil)
|
||||
generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10)
|
||||
text = self.processor.tokenizer.batch_decode(generated_tokens)
|
||||
end_sequence = text[0].split("\x04")[1]
|
||||
clean_sequence = (
|
||||
end_sequence[: end_sequence.find("|ENDOFTEXT|") + len("|ENDOFTEXT|")]
|
||||
if "|ENDOFTEXT|" in end_sequence
|
||||
else end_sequence
|
||||
)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence[1:])
|
||||
|
||||
inputs = processor(text=text_prompt_coco_captioning, images=image, return_tensors="pt")
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=10)
|
||||
|
||||
# take the last 8 tokens (in order to skip special \n\x04 characters) and decode them
|
||||
generated_text = processor.batch_decode(generated_ids[:, -8:], skip_special_tokens=True)[0]
|
||||
self.assertEqual(generated_text, "A blue bus parked on the side of a road.")
|
||||
|
||||
|
||||
"""
|
||||
|
@ -6,4 +6,5 @@ docs/source/en/task_summary.md
|
||||
docs/source/en/tasks/prompting.md
|
||||
src/transformers/models/blip_2/modeling_blip_2.py
|
||||
src/transformers/models/ctrl/modeling_ctrl.py
|
||||
src/transformers/models/kosmos2/modeling_kosmos2.py
|
||||
src/transformers/models/fuyu/modeling_fuyu.py
|
||||
src/transformers/models/kosmos2/modeling_kosmos2.py
|
Loading…
Reference in New Issue
Block a user