mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-14 01:58:22 +06:00

* First draft * Make fixup * Make forward pass work * Improve code * More improvements * More improvements * Make predictions match * More improvements * Improve image processor * Fix model tests * Add classic decoder * Convert classic decoder * Verify image processor * Fix classic decoder logits * Clean up * Add post_process_pose_estimation * Improve post_process_pose_estimation * Use AutoBackbone * Add support for MoE models * Fix tests, improve num_experts * Improve variable names * Make fixup * More improvements * Improve post_process_pose_estimation * Compute centers and scales * Improve postprocessing * More improvements * Fix ViTPoseBackbone tests * Add docstrings, fix image processor tests * Update index * Use is_cv2_available * Add model to toctree * Add cv2 to doc tests * Remove script * Improve conversion script * Add coco_to_pascal_voc * Add box_to_center_and_scale to image_transforms * Update tests * Add integration test * Fix merge * Address comments * Replace numpy by pytorch, improve docstrings * Remove get_input_embeddings * Address comments * Move coco_to_pascal_voc * Address comment * Fix style * Address comments * Fix test * Address comment * Remove udp * Remove comment * [WIP] need to check if the numpy function is same as cv * add scipy affine_transform * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * refactor convert * add output_shape * add atol 5e-2 * Use hf_hub_download in conversion script * make box_to_center more applicable * skip test_get_set_embedding * fix to accept array and fix CI * add co-contributor * make it to tensor type output * add torch * change to torch tensor * add more test * minor change * CI test change * import torch should be above ImageProcessor * make style * try not use torch in def * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: NielsRogge 
<48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * fix * fix * add caution * make more detail about dataset_index * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> * add docs * Update docs/source/en/model_doc/vitpose.md * Update src/transformers/models/vitpose/configuration_vitpose.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Revert "Update src/transformers/__init__.py" This reverts commit7ffa504450
. * change name * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vitpose/test_modeling_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/vitpose.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * move vitpose only function to image_processor * raise valueerror when using timm backbone * use out_indices * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove camel-case of def flip_back * rename vitposeEstimatorOutput * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix confused camelcase of MLP * remove in-place logic * clear scale description * make consistent batch format * docs update * formatting docstring * add batch tests * test docs change * Update src/transformers/models/vitpose/image_processing_vitpose.py * Update src/transformers/models/vitpose/configuration_vitpose.py * chagne ViT to Vit * change to enable MoE * make fix-copies * Update docs/source/en/model_doc/vitpose.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * extract udp * add more described docs * simple fix * change to accept target_size * make style * Update src/transformers/models/vitpose/image_processing_vitpose.py 
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose/configuration_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * change to `verify_backbone_config_arguments` * Update docs/source/en/model_doc/vitpose.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove unnecessary copy * make config immutable * enable gradient checkpointing * update inappropriate docstring * linting docs * split function for visibility * make style * check isinstances * change to acceptable use_pretrained_backbone * make style * remove copy in docs * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update docs/source/en/model_doc/vitpose.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * simple fix + make style * change input config of activation function to string * Update docs/source/en/model_doc/vitpose.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * tmp docs * delete index.md * make fix-copies * simple fix * change conversion to sam2/mllama style * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * refactor convert * add supervision * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * remove reduntant def * seperate code block for visualization * add validation for num_moe * final commit * add labels * [run-slow] vitpose, vitpose_backbone * Update src/transformers/models/vitpose/convert_vitpose_to_hf.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * enable all 
conversion * final commit * [run-slow] vitpose, vitpose_backbone * ruff check --fix * [run-slow] vitpose, vitpose_backbone * rename split module * [run-slow] vitpose, vitpose_backbone * fix pos_embed * Simplify init * Revert "fix pos_embed" This reverts commit 2c56a4806e
. * refactor single loop * allow flag to enable custom model * efficiency of MoE to not use unused experts * make style * Fix range -> arange to avoid warning * Revert MOE router, a new one does not work * Fix postprocessing a bit (labels) * Fix type hint * Fix docs snippets * Fix links to checkpoints * Fix checkpoints in tests * Fix test * Add image to docs --------- Co-authored-by: Niels Rogge <nielsrogge@nielss-mbp.home> Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> Co-authored-by: sangbumchoi <danielsejong55@gmail.com> Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
200 lines
6.9 KiB
Python
200 lines
6.9 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch VitPose backbone model."""
|
|
|
|
import inspect
|
|
import unittest
|
|
|
|
from transformers import VitPoseBackboneConfig
|
|
from transformers.testing_utils import require_torch
|
|
from transformers.utils import is_torch_available, is_vision_available
|
|
|
|
from ...test_backbone_common import BackboneTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|
|
|
|
|
if is_torch_available():
|
|
from transformers import VitPoseBackbone
|
|
|
|
|
|
if is_vision_available():
|
|
pass
|
|
|
|
|
|
class VitPoseBackboneModelTester:
    """Helper that stores small hyper-parameters and builds a config plus inputs for the tests.

    The constructor only records its arguments as attributes and derives `seq_length`
    from the image and patch sizes. The `prepare_*` methods create a
    `VitPoseBackboneConfig` and the tensors passed to the model under test.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=(16 * 8, 12 * 8),
        patch_size=(8, 8),
        num_channels=3,
        is_training=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_sequence_label_size=10,
        initializer_range=0.02,
        num_labels=2,
        scope=None,
    ):
        """Record the test hyper-parameters.

        Args:
            parent: The test case that owns this tester (stored as `self.parent`).
            image_size: Two-element sequence; element 0 and element 1 are divided by the
                matching `patch_size` element to compute the number of patches.
            patch_size: Two-element sequence of patch sizes, indexed like `image_size`.

        All remaining arguments are stored unchanged under an attribute of the same name.
        """
        self.parent = parent
        self.batch_size = batch_size
        # Copy into fresh lists: the defaults are immutable tuples and every instance gets
        # its own list, so mutating one tester's sizes can never leak into another tester.
        self.image_size = list(image_size)
        self.patch_size = list(patch_size)
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.scope = scope

        # in VitPoseBackbone, the seq length equals the number of patches
        num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
        self.seq_length = num_patches

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values, labels)`; `labels` is None when `use_labels` is falsy."""
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])

        labels = None
        if self.use_labels:
            # presumably one class id per batch item, drawn below `type_sequence_label_size`
            # - verify against `ids_tensor`
            labels = ids_tensor([self.batch_size], self.type_sequence_label_size)

        config = self.get_config()

        return config, pixel_values, labels

    def get_config(self):
        """Build a `VitPoseBackboneConfig` from the hyper-parameters stored on this tester."""
        return VitPoseBackboneConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            initializer_range=self.initializer_range,
            num_labels=self.num_labels,
        )

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` holds only `pixel_values`."""
        # The labels produced by `prepare_config_and_inputs` are not part of the common inputs.
        config, pixel_values, _ = self.prepare_config_and_inputs()
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Model tests for VitPoseBackbone. Some tests of test_modeling_common.py are overwritten or skipped here, as
    VitPoseBackbone does not use input_ids, inputs_embeds, attention_mask and seq_length.
    """

    # Switches that are all turned off for this model; presumably consumed by ModelTesterMixin
    # to disable the matching common tests - verify against test_modeling_common.py.
    test_head_masking = False
    test_resize_embeddings = False
    test_pruning = False
    fx_compatible = False

    all_model_classes = (VitPoseBackbone,) if is_torch_available() else ()

    def setUp(self):
        # Fresh helper objects for every test method.
        self.model_tester = VitPoseBackboneModelTester(self)
        self.config_tester = ConfigTester(
            self,
            config_class=VitPoseBackboneConfig,
            has_text_modality=False,
            hidden_size=37,
        )

    def test_config(self):
        # Delegates to the shared configuration checks.
        self.config_tester.run_common_tests()

    # ---- common tests that are skipped for this model; the reason is given on each decorator ----

    @unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not support feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not output a loss")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not support training yet")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="VitPoseBackbone does not support training yet")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    def test_forward_signature(self):
        """Check that the first parameter of every model's `forward` is `pixel_values`."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            forward_parameters = inspect.signature(model_class(config).forward).parameters
            # the parameters mapping is ordered, so the leading name is deterministic
            leading_arg_names = list(forward_parameters)[:1]
            self.assertListEqual(leading_arg_names, ["pixel_values"])
|
|
|
|
|
|
@require_torch
class VitPoseBackboneTest(unittest.TestCase, BackboneTesterMixin):
    # Test case combining unittest.TestCase with BackboneTesterMixin for VitPoseBackbone;
    # presumably the mixin supplies the shared backbone tests - verify against test_backbone_common.py.
    config_class = VitPoseBackboneConfig
    all_model_classes = (VitPoseBackbone,) if is_torch_available() else ()

    has_attentions = False

    def setUp(self):
        # A new tester is created before each test method.
        tester = VitPoseBackboneModelTester(self)
        self.model_tester = tester
|