# coding=utf-8
# Copyright 2023 The Intel Team Authors, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch TVP model."""

import unittest

from transformers import ResNetConfig, TvpConfig
from transformers.testing_utils import require_torch, require_vision, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import TvpForVideoGrounding, TvpModel

if is_vision_available():
    from PIL import Image

    from transformers import TvpImageProcessor


# Copied from tests.models.videomae.test_modeling_videomae.VideoMAEModelTester with VideoMAE->TVP
class TVPModelTester:
    def __init__(
        self,
        parent,
        batch_size=1,
        seq_length=2,
        alpha=1.0,
        beta=0.1,
        visual_prompter_type="framepad",
        visual_prompter_apply="replace",
        num_frames=2,
        max_img_size=448,
        visual_prompt_size=96,
        vocab_size=100,
        hidden_size=32,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=30,
        max_grid_col_position_embeddings=30,
        max_grid_row_position_embeddings=30,
        hidden_dropout_prob=0.1,
        hidden_act="gelu",
        layer_norm_eps=1e-12,
        initializer_range=0.02,
        pad_token_id=0,
        type_vocab_size=2,
        attention_probs_dropout_prob=0.1,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.input_id_length = seq_length
        self.seq_length = seq_length + 10 + 784  # include text prompt length and visual input length
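        # A rough breakdown of the extra tokens (an assumption, inferred from the shape checks
        # in this file): the 10 comes from the text prompt and the 784 from the visual features,
        # which matches a flattened 28 x 28 feature grid for a 448 x 448 input (448 / 16 = 28).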
        self.alpha = alpha
        self.beta = beta
        self.visual_prompter_type = visual_prompter_type
        self.visual_prompter_apply = visual_prompter_apply
        self.num_frames = num_frames
        self.max_img_size = max_img_size
        self.visual_prompt_size = visual_prompt_size
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.max_grid_col_position_embeddings = max_grid_col_position_embeddings
        self.max_grid_row_position_embeddings = max_grid_row_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id
        self.type_vocab_size = type_vocab_size
        self.is_training = False
        self.num_channels = 3

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.input_id_length], self.vocab_size)
        attention_mask = random_attention_mask([self.batch_size, self.input_id_length])
        pixel_values = floats_tensor(
            [self.batch_size, self.num_frames, self.num_channels, self.max_img_size, self.max_img_size]
        )

        config = self.get_config()

        return (config, input_ids, pixel_values, attention_mask)

    def get_config(self):
        resnet_config = ResNetConfig(
            num_channels=3,
            embeddings_size=64,
            hidden_sizes=[64, 128],
            depths=[2, 2],
            hidden_act="relu",
            out_features=["stage2"],
            out_indices=[2],
        )
        return TvpConfig(
            backbone_config=resnet_config,
            alpha=self.alpha,
            beta=self.beta,
            visual_prompter_type=self.visual_prompter_type,
            visual_prompter_apply=self.visual_prompter_apply,
            num_frames=self.num_frames,
            max_img_size=self.max_img_size,
            visual_prompt_size=self.visual_prompt_size,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_grid_col_position_embeddings=self.max_grid_col_position_embeddings,
            max_grid_row_position_embeddings=self.max_grid_row_position_embeddings,
            layer_norm_eps=self.layer_norm_eps,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            type_vocab_size=self.type_vocab_size,
        )

    def create_and_check_model(self, config, input_ids, pixel_values, attention_mask):
        model = TvpModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, pixel_values, attention_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, pixel_values, attention_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "pixel_values": pixel_values, "attention_mask": attention_mask}
        return config, inputs_dict
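
    # The two prepare_* methods above implement the ModelTesterMixin contract. Roughly, the
    # shared tests consume them like this (a sketch, not the mixin's exact code):
    #
    #     config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    #     model = model_class(config).to(torch_device).eval()
    #     outputs = model(**inputs_dict)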


@require_torch
class TVPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as TVP does not use inputs_embeds.
    The seq_length in TVP contains the textual inputs, the visual inputs and the prompt.
    """

    all_model_classes = (TvpModel, TvpForVideoGrounding) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": TvpModel, "temporal-video-grounding": TvpForVideoGrounding}
        if is_torch_available()
        else {}
    )

    def setUp(self):
        self.model_tester = TVPModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="TVP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="TVPModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    # override as the `logit_scale` parameter initialization is different for TVP
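    # (_config_zero_init returns a copy of the config with its initializer ranges shrunk
    # towards zero, so every parameter that honours the config should end up with a mean
    # close to 0.0; the wide delta below only flags grossly mis-initialized weights.)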
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # params are randomly initialized.
                    self.assertAlmostEqual(
                        param.data.mean().item(),
                        0.0,
                        delta=1.0,
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_vision
@require_torch
class TvpModelIntegrationTests(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp") if is_vision_available() else None

    def test_inference_no_head(self):
        model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
        input_ids = torch.tensor([[1, 2]]).to(torch_device)
        attention_mask = torch.tensor([[1, 1]]).to(torch_device)
        encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})

        with torch.no_grad():
            outputs = model(**encoding)
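
        # 796 hidden states = 2 text tokens + 10 prompt tokens + a 28 x 28 (= 784) visual grid
        # (an assumption, consistent with the seq_length arithmetic in TVPModelTester above).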
        expected_shape = torch.Size((1, 796, 128))
        assert outputs.last_hidden_state.shape == expected_shape
        expected_slice = torch.tensor(
            [[-0.4902, -0.4121, -1.7872], [-0.2184, 2.1211, -0.9371], [0.1180, 0.5003, -0.1727]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))

    def test_inference_with_head(self):
        model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
        input_ids = torch.tensor([[1, 2]]).to(torch_device)
        attention_mask = torch.tensor([[1, 1]]).to(torch_device)
        encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})

        with torch.no_grad():
            outputs = model(**encoding)
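
        # The grounding head emits two logits per video, which we read as the normalized
        # (start, end) boundaries of the predicted segment (an assumption based on the
        # video-grounding task; the checkpoint here is a tiny random model).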
        expected_shape = torch.Size((1, 2))
        assert outputs.logits.shape == expected_shape
        expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device)
        self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4))
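

# These tests can be run on their own with, e.g.:
#     python -m pytest tests/models/tvp/test_modeling_tvp.py -v
# (the path is assumed from the standard transformers test layout)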