Add ViTPose (#30530)

* First draft

* Make fixup

* Make forward pass work

* Improve code

* More improvements

* More improvements

* Make predictions match

* More improvements

* Improve image processor

* Fix model tests

* Add classic decoder

* Convert classic decoder

* Verify image processor

* Fix classic decoder logits

* Clean up

* Add post_process_pose_estimation

* Improve post_process_pose_estimation

* Use AutoBackbone

* Add support for MoE models

* Fix tests, improve num_experts

* Improve variable names

* Make fixup

* More improvements

* Improve post_process_pose_estimation

* Compute centers and scales

* Improve postprocessing

* More improvements

* Fix ViTPoseBackbone tests

* Add docstrings, fix image processor tests

* Update index

* Use is_cv2_available

* Add model to toctree

* Add cv2 to doc tests

* Remove script

* Improve conversion script

* Add coco_to_pascal_voc

* Add box_to_center_and_scale to image_transforms

* Update tests

* Add integration test

* Fix merge

* Address comments

* Replace numpy by pytorch, improve docstrings

* Remove get_input_embeddings

* Address comments

* Move coco_to_pascal_voc

* Address comment

* Fix style

* Address comments

* Fix test

* Address comment

* Remove udp

* Remove comment

* [WIP] need to check if the numpy function is same as cv

* add scipy affine_transform

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* refactor convert

* add output_shape

* add atol 5e-2

* Use hf_hub_download in conversion script

* make box_to_center more applicable

* skip test_get_set_embedding

* fix to accept array and fix CI

* add co-contributor

* make it to tensor type output

* add torch

* change to torch tensor

* add more test

* minor change

* CI test change

* import torch should be above ImageProcessor

* make style

* try not use torch in def

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/vitpose/modeling_vitpose.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fix

* fix

* add caution

* make more detail about dataset_index

* Update src/transformers/models/vitpose/modeling_vitpose.py

Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>

* add docs

* Update docs/source/en/model_doc/vitpose.md

* Update src/transformers/models/vitpose/configuration_vitpose.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Revert "Update src/transformers/__init__.py"

This reverts commit 7ffa504450.

* change name

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/vitpose/test_modeling_vitpose.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/vitpose.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vitpose/modeling_vitpose.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* move vitpose only function to image_processor

* raise valueerror when using timm backbone

* use out_indices

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove camel-case of def flip_back

* rename vitposeEstimatorOutput

* Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix confused camelcase of MLP

* remove in-place logic

* clear scale description

* make consistent batch format

* docs update

* formatting docstring

* add batch tests

* test docs change

* Update src/transformers/models/vitpose/image_processing_vitpose.py

* Update src/transformers/models/vitpose/configuration_vitpose.py

* change ViT to Vit

* change to enable MoE

* make fix-copies

* Update docs/source/en/model_doc/vitpose.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* extract udp

* add more described docs

* simple fix

* change to accept target_size

* make style

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vitpose/configuration_vitpose.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change to `verify_backbone_config_arguments`

* Update docs/source/en/model_doc/vitpose.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove unnecessary copy

* make config immutable

* enable gradient checkpointing

* update inappropriate docstring

* linting docs

* split function for visibility

* make style

* check isinstances

* change to acceptable use_pretrained_backbone

* make style

* remove copy in docs

* Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update docs/source/en/model_doc/vitpose.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/vitpose/modeling_vitpose.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* simple fix + make style

* change input config of activation function to string

* Update docs/source/en/model_doc/vitpose.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* tmp docs

* delete index.md

* make fix-copies

* simple fix

* change conversion to sam2/mllama style

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/vitpose/image_processing_vitpose.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* refactor convert

* add supervision

* Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* remove redundant def

* separate code block for visualization

* add validation for num_moe

* final commit

* add labels

* [run-slow] vitpose, vitpose_backbone

* Update src/transformers/models/vitpose/convert_vitpose_to_hf.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* enable all conversion

* final commit

* [run-slow] vitpose, vitpose_backbone

* ruff check --fix

* [run-slow] vitpose, vitpose_backbone

* rename split module

* [run-slow] vitpose, vitpose_backbone

* fix pos_embed

* Simplify init

* Revert "fix pos_embed"

This reverts commit 2c56a4806e.

* refactor single loop

* allow flag to enable custom model

* efficiency of MoE to not use unused experts

* make style

* Fix range -> arange to avoid warning

* Revert MOE router, a new one does not work

* Fix postprocessing a bit (labels)

* Fix type hint

* Fix docs snippets

* Fix links to checkpoints

* Fix checkpoints in tests

* Fix test

* Add image to docs

---------

Co-authored-by: Niels Rogge <nielsrogge@nielss-mbp.home>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: sangbumchoi <danielsejong55@gmail.com>
Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
NielsRogge 2025-01-08 17:02:14 +01:00 committed by GitHub
parent 4349a0e401
commit 8490d3159c
24 changed files with 3350 additions and 0 deletions

View File

@ -741,6 +741,8 @@
title: ViTMatte
- local: model_doc/vit_msn
title: ViTMSN
- local: model_doc/vitpose
title: ViTPose
- local: model_doc/yolos
title: YOLOS
- local: model_doc/zoedepth

View File

@ -356,6 +356,8 @@ Flax), PyTorch, and/or TensorFlow.
| [ViTMAE](model_doc/vit_mae) | ✅ | ✅ | ❌ |
| [ViTMatte](model_doc/vitmatte) | ✅ | ❌ | ❌ |
| [ViTMSN](model_doc/vit_msn) | ✅ | ❌ | ❌ |
| [VitPose](model_doc/vitpose) | ✅ | ❌ | ❌ |
| [VitPoseBackbone](model_doc/vitpose_backbone) | ✅ | ❌ | ❌ |
| [VITS](model_doc/vits) | ✅ | ❌ | ❌ |
| [ViViT](model_doc/vivit) | ✅ | ❌ | ❌ |
| [Wav2Vec2](model_doc/wav2vec2) | ✅ | ✅ | ✅ |

View File

@ -0,0 +1,254 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# VitPose
## Overview
The VitPose model was proposed in [ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation](https://arxiv.org/abs/2204.12484) by Yufei Xu, Jing Zhang, Qiming Zhang, Dacheng Tao. VitPose employs a standard, non-hierarchical [Vision Transformer](https://arxiv.org/pdf/2010.11929v2) as backbone for the task of keypoint estimation. A simple decoder head is added on top to predict the heatmaps from a given image. Despite its simplicity, the model gets state-of-the-art results on the challenging MS COCO Keypoint Detection benchmark.
The abstract from the paper is the following:
*Although no specific domain knowledge is considered in the design, plain vision transformers have shown excellent performance in visual recognition tasks. However, little effort has been made to reveal the potential of such simple structures for pose estimation tasks. In this paper, we show the surprisingly good capabilities of plain vision transformers for pose estimation from various aspects, namely simplicity in model structure, scalability in model size, flexibility in training paradigm, and transferability of knowledge between models, through a simple baseline model called ViTPose. Specifically, ViTPose employs plain and non-hierarchical vision transformers as backbones to extract features for a given person instance and a lightweight decoder for pose estimation. It can be scaled up from 100M to 1B parameters by taking the advantages of the scalable model capacity and high parallelism of transformers, setting a new Pareto front between throughput and performance. Besides, ViTPose is very flexible regarding the attention type, input resolution, pre-training and finetuning strategy, as well as dealing with multiple pose tasks. We also empirically demonstrate that the knowledge of large ViTPose models can be easily transferred to small ones via a simple knowledge token. Experimental results show that our basic ViTPose model outperforms representative methods on the challenging MS COCO Keypoint Detection benchmark, while the largest model sets a new state-of-the-art.*
![vitpose-architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-architecture.png)
This model was contributed by [nielsr](https://huggingface.co/nielsr) and [sangbumchoi](https://github.com/SangbumChoi).
The original code can be found [here](https://github.com/ViTAE-Transformer/ViTPose).
## Usage Tips
ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt_detr.md), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints.
```py
import torch
import requests
import numpy as np
from PIL import Image
from transformers import (
AutoProcessor,
RTDetrForObjectDetection,
VitPoseForPoseEstimation,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
url = "http://images.cocodataset.org/val2017/000000000139.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# ------------------------------------------------------------------------
# Stage 1. Detect humans on the image
# ------------------------------------------------------------------------
# You can use any detector of your choice
person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
inputs = person_image_processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = person_model(**inputs)
results = person_image_processor.post_process_object_detection(
outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
)
result = results[0] # take first image results
# The human class corresponds to label index 0 in the COCO dataset
person_boxes = result["boxes"][result["labels"] == 0]
person_boxes = person_boxes.cpu().numpy()
# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
# ------------------------------------------------------------------------
# Stage 2. Detect keypoints for each person found
# ------------------------------------------------------------------------
image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
image_pose_result = pose_results[0] # results for first image
```
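`post_process_pose_estimation` returns, for every input image, a list of per-person result dictionaries. As a quick sanity check you can print the named keypoints; the snippet below is a minimal sketch that assumes each per-person dictionary exposes `keypoints`, `scores` and `labels` tensors and that the checkpoint's config carries the 17 COCO keypoint names in `id2label` (as set by the conversion script further below).

```py
for i, person_pose in enumerate(image_pose_result):
    print(f"Person #{i}")
    for keypoint, label, score in zip(person_pose["keypoints"], person_pose["labels"], person_pose["scores"]):
        # id2label maps keypoint indices to names such as "Nose" or "L_Shoulder"
        keypoint_name = model.config.id2label[label.item()]
        print(f"  {keypoint_name}: x={keypoint[0]:.1f}, y={keypoint[1]:.1f}, score={score:.2f}")
```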
### Visualization with the `supervision` library
```py
import supervision as sv
xy = torch.stack([pose_result['keypoints'] for pose_result in image_pose_result]).cpu().numpy()
scores = torch.stack([pose_result['scores'] for pose_result in image_pose_result]).cpu().numpy()
key_points = sv.KeyPoints(
xy=xy, confidence=scores
)
edge_annotator = sv.EdgeAnnotator(
color=sv.Color.GREEN,
thickness=1
)
vertex_annotator = sv.VertexAnnotator(
color=sv.Color.RED,
radius=2
)
annotated_frame = edge_annotator.annotate(
scene=image.copy(),
key_points=key_points
)
annotated_frame = vertex_annotator.annotate(
scene=annotated_frame,
key_points=key_points
)
```
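A short follow-up, assuming the annotators return the annotated scene in a format convertible to a NumPy array (the output filename is arbitrary):

```py
# Convert to a PIL image regardless of whether supervision returned a PIL image or a NumPy array
Image.fromarray(np.asarray(annotated_frame)).save("vitpose_supervision_result.jpg")
```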
### Visualization for advanced users
```py
import math
import cv2
def draw_points(image, keypoints, scores, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
    if pose_keypoint_color is not None:
        assert len(pose_keypoint_color) == len(keypoints)
    for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
        x_coord, y_coord = int(kpt[0]), int(kpt[1])
        if kpt_score > keypoint_score_threshold:
            color = tuple(int(c) for c in pose_keypoint_color[kid])
            if show_keypoint_weight:
                cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
                transparency = max(0, min(1, kpt_score))
                cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
            else:
                cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
def draw_links(image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width=2):
    height, width, _ = image.shape
    if keypoint_edges is not None and link_colors is not None:
        assert len(link_colors) == len(keypoint_edges)
        for sk_id, sk in enumerate(keypoint_edges):
            x1, y1, score1 = (int(keypoints[sk[0], 0]), int(keypoints[sk[0], 1]), scores[sk[0]])
            x2, y2, score2 = (int(keypoints[sk[1], 0]), int(keypoints[sk[1], 1]), scores[sk[1]])
            if (
                x1 > 0
                and x1 < width
                and y1 > 0
                and y1 < height
                and x2 > 0
                and x2 < width
                and y2 > 0
                and y2 < height
                and score1 > keypoint_score_threshold
                and score2 > keypoint_score_threshold
            ):
                color = tuple(int(c) for c in link_colors[sk_id])
                if show_keypoint_weight:
                    X = (x1, x2)
                    Y = (y1, y2)
                    mean_x = np.mean(X)
                    mean_y = np.mean(Y)
                    length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
                    angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
                    polygon = cv2.ellipse2Poly(
                        (int(mean_x), int(mean_y)), (int(length / 2), int(stick_width)), int(angle), 0, 360, 1
                    )
                    cv2.fillConvexPoly(image, polygon, color)
                    transparency = max(0, min(1, 0.5 * (keypoints[sk[0], 2] + keypoints[sk[1], 2])))
                    cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
                else:
                    cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)
# Note: keypoint_edges and color palette are dataset-specific
keypoint_edges = model.config.edges
palette = np.array(
[
[255, 128, 0],
[255, 153, 51],
[255, 178, 102],
[230, 230, 0],
[255, 153, 255],
[153, 204, 255],
[255, 102, 255],
[255, 51, 255],
[102, 178, 255],
[51, 153, 255],
[255, 153, 153],
[255, 102, 102],
[255, 51, 51],
[153, 255, 153],
[102, 255, 102],
[51, 255, 51],
[0, 255, 0],
[0, 0, 255],
[255, 0, 0],
[255, 255, 255],
]
)
link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]
numpy_image = np.array(image)
for pose_result in image_pose_result:
    scores = np.array(pose_result["scores"])
    keypoints = np.array(pose_result["keypoints"])
    # draw each point on image
    draw_points(numpy_image, keypoints, scores, keypoint_colors, keypoint_score_threshold=0.3, radius=4, show_keypoint_weight=False)
    # draw links
    draw_links(numpy_image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold=0.3, thickness=1, show_keypoint_weight=False)
pose_image = Image.fromarray(numpy_image)
pose_image
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-coco.jpg" alt="drawing" width="600"/>
### MoE backbone
To enable the MoE (Mixture of Experts) mechanism in the backbone, pass an appropriate configuration (such as `num_experts`) and provide the `dataset_index` input to the backbone model. MoE is not enabled with the default parameters. The code snippet below shows how to use it.
```py
>>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
>>> import torch
>>> config = VitPoseBackboneConfig(num_experts=3, out_indices=[-1])
>>> model = VitPoseBackbone(config)
>>> pixel_values = torch.randn(3, 3, 256, 192)
>>> dataset_index = torch.tensor([1, 2, 3])
>>> outputs = model(pixel_values, dataset_index)
```
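Here `dataset_index` contains one entry per image in the batch and indicates which dataset each image comes from, which the backbone uses to route the image through the corresponding expert in the MoE layers.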
## VitPoseImageProcessor
[[autodoc]] VitPoseImageProcessor
- preprocess
## VitPoseConfig
[[autodoc]] VitPoseConfig
## VitPoseForPoseEstimation
[[autodoc]] VitPoseForPoseEstimation
- forward

View File

@ -834,6 +834,8 @@ _import_structure = {
"models.vit_msn": ["ViTMSNConfig"],
"models.vitdet": ["VitDetConfig"],
"models.vitmatte": ["VitMatteConfig"],
"models.vitpose": ["VitPoseConfig"],
"models.vitpose_backbone": ["VitPoseBackboneConfig"],
"models.vits": [
"VitsConfig",
"VitsTokenizer",
@ -1266,6 +1268,7 @@ else:
_import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"])
_import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"])
_import_structure["models.vitmatte"].append("VitMatteImageProcessor")
_import_structure["models.vitpose"].append("VitPoseImageProcessor")
_import_structure["models.vivit"].append("VivitImageProcessor")
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
_import_structure["models.zoedepth"].append("ZoeDepthImageProcessor")
@ -3755,6 +3758,18 @@ else:
"VitMattePreTrainedModel",
]
)
_import_structure["models.vitpose"].extend(
[
"VitPoseForPoseEstimation",
"VitPosePreTrainedModel",
]
)
_import_structure["models.vitpose_backbone"].extend(
[
"VitPoseBackbone",
"VitPoseBackbonePreTrainedModel",
]
)
_import_structure["models.vits"].extend(
[
"VitsModel",
@ -5877,6 +5892,8 @@ if TYPE_CHECKING:
from .models.vit_msn import ViTMSNConfig
from .models.vitdet import VitDetConfig
from .models.vitmatte import VitMatteConfig
from .models.vitpose import VitPoseConfig
from .models.vitpose_backbone import VitPoseBackboneConfig
from .models.vits import (
VitsConfig,
VitsTokenizer,
@ -6311,6 +6328,7 @@ if TYPE_CHECKING:
from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor
from .models.vit import ViTFeatureExtractor, ViTImageProcessor
from .models.vitmatte import VitMatteImageProcessor
from .models.vitpose import VitPoseImageProcessor
from .models.vivit import VivitImageProcessor
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
from .models.zoedepth import ZoeDepthImageProcessor
@ -8294,6 +8312,11 @@ if TYPE_CHECKING:
VitMatteForImageMatting,
VitMattePreTrainedModel,
)
from .models.vitpose import (
VitPoseForPoseEstimation,
VitPosePreTrainedModel,
)
from .models.vitpose_backbone import VitPoseBackbone, VitPoseBackbonePreTrainedModel
from .models.vits import (
VitsModel,
VitsPreTrainedModel,

View File

@ -277,6 +277,8 @@ from . import (
vit_msn,
vitdet,
vitmatte,
vitpose,
vitpose_backbone,
vits,
vivit,
wav2vec2,

View File

@ -309,6 +309,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("vit_msn", "ViTMSNConfig"),
("vitdet", "VitDetConfig"),
("vitmatte", "VitMatteConfig"),
("vitpose", "VitPoseConfig"),
("vitpose_backbone", "VitPoseBackboneConfig"),
("vits", "VitsConfig"),
("vivit", "VivitConfig"),
("wav2vec2", "Wav2Vec2Config"),
@ -642,6 +644,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("vit_msn", "ViTMSN"),
("vitdet", "VitDet"),
("vitmatte", "ViTMatte"),
("vitpose", "VitPose"),
("vitpose_backbone", "VitPoseBackbone"),
("vits", "VITS"),
("vivit", "ViViT"),
("wav2vec2", "Wav2Vec2"),

View File

@ -1396,6 +1396,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("textnet", "TextNetBackbone"),
("timm_backbone", "TimmBackbone"),
("vitdet", "VitDetBackbone"),
("vitpose_backbone", "VitPoseBackbone"),
]
)

View File

@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_vitpose import *
from .image_processing_vitpose import *
from .modeling_vitpose import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VitPose model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class VitPoseConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VitPoseForPoseEstimation`]. It is used to instantiate a
VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VitPose
[usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitPoseBackboneConfig()`):
The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_factor (`int`, *optional*, defaults to 4):
Factor to upscale the feature maps coming from the ViT backbone.
use_simple_decoder (`bool`, *optional*, defaults to `True`):
Whether to use a `VitPoseSimpleDecoder` to decode the feature maps from the backbone into heatmaps. Otherwise it uses `VitPoseClassicDecoder`.
Example:
```python
>>> from transformers import VitPoseConfig, VitPoseForPoseEstimation
>>> # Initializing a VitPose configuration
>>> configuration = VitPoseConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = VitPoseForPoseEstimation(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "vitpose"
def __init__(
self,
backbone_config: PretrainedConfig = None,
backbone: str = None,
use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
backbone_kwargs: dict = None,
initializer_range: float = 0.02,
scale_factor: int = 4,
use_simple_decoder: bool = True,
**kwargs,
):
super().__init__(**kwargs)
if use_pretrained_backbone:
logger.info(
"`use_pretrained_backbone` is `True`. For the pure inference purpose of VitPose weight do not set this value."
)
if use_timm_backbone:
raise ValueError("use_timm_backbone set `True` is not supported at the moment.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `VitPose` backbone.")
backbone_config = CONFIG_MAPPING["vitpose_backbone"](out_indices=[4])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.initializer_range = initializer_range
self.scale_factor = scale_factor
self.use_simple_decoder = use_simple_decoder
__all__ = ["VitPoseConfig"]

View File

@ -0,0 +1,355 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert VitPose checkpoints from the original repository.
URL: https://github.com/vitae-transformer/vitpose
"""
import argparse
import os
import re
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import VitPoseBackboneConfig, VitPoseConfig, VitPoseForPoseEstimation, VitPoseImageProcessor
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"patch_embed.proj": "embeddings.patch_embeddings.projection",
r"pos_embed": "embeddings.position_embeddings",
r"blocks": "encoder.layer",
r"attn.proj": "attention.output.dense",
r"attn": "attention.self",
r"norm1": "layernorm_before",
r"norm2": "layernorm_after",
r"last_norm": "layernorm",
r"keypoint_head": "head",
r"final_layer": "conv",
}
MODEL_TO_FILE_NAME_MAPPING = {
"vitpose-base-simple": "vitpose-b-simple.pth",
"vitpose-base": "vitpose-b.pth",
"vitpose-base-coco-aic-mpii": "vitpose_base_coco_aic_mpii.pth",
"vitpose-plus-base": "vitpose+_base.pth",
}
def get_config(model_name):
num_experts = 6 if "plus" in model_name else 1
part_features = 192 if "plus" in model_name else 0
backbone_config = VitPoseBackboneConfig(out_indices=[12], num_experts=num_experts, part_features=part_features)
# size of the architecture
if "small" in model_name:
backbone_config.hidden_size = 768
backbone_config.intermediate_size = 2304
backbone_config.num_hidden_layers = 8
backbone_config.num_attention_heads = 8
elif "large" in model_name:
backbone_config.hidden_size = 1024
backbone_config.intermediate_size = 4096
backbone_config.num_hidden_layers = 24
backbone_config.num_attention_heads = 16
elif "huge" in model_name:
backbone_config.hidden_size = 1280
backbone_config.intermediate_size = 5120
backbone_config.num_hidden_layers = 32
backbone_config.num_attention_heads = 16
use_simple_decoder = "simple" in model_name
edges = [
[15, 13],
[13, 11],
[16, 14],
[14, 12],
[11, 12],
[5, 11],
[6, 12],
[5, 6],
[5, 7],
[6, 8],
[7, 9],
[8, 10],
[1, 2],
[0, 1],
[0, 2],
[1, 3],
[2, 4],
[3, 5],
[4, 6],
]
id2label = {
0: "Nose",
1: "L_Eye",
2: "R_Eye",
3: "L_Ear",
4: "R_Ear",
5: "L_Shoulder",
6: "R_Shoulder",
7: "L_Elbow",
8: "R_Elbow",
9: "L_Wrist",
10: "R_Wrist",
11: "L_Hip",
12: "R_Hip",
13: "L_Knee",
14: "R_Knee",
15: "L_Ankle",
16: "R_Ankle",
}
label2id = {v: k for k, v in id2label.items()}
config = VitPoseConfig(
backbone_config=backbone_config,
num_labels=17,
use_simple_decoder=use_simple_decoder,
edges=edges,
id2label=id2label,
label2id=label2id,
)
return config
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
"""
This function should be applied only once, on the concatenated keys to efficiently rename using
the key mappings.
"""
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
# We will verify our results on a COCO image
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000000139.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return image
@torch.no_grad()
def write_model(model_path, model_name, push_to_hub, check_logits=True):
os.makedirs(model_path, exist_ok=True)
# ------------------------------------------------------------
# Vision model params and config
# ------------------------------------------------------------
# params from config
config = get_config(model_name)
# ------------------------------------------------------------
# Convert weights
# ------------------------------------------------------------
# load original state_dict
filename = MODEL_TO_FILE_NAME_MAPPING[model_name]
print(f"Fetching all parameters from the checkpoint at {filename}...")
checkpoint_path = hf_hub_download(
repo_id="nielsr/vitpose-original-checkpoints", filename=filename, repo_type="model"
)
print("Converting model...")
original_state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
all_keys = list(original_state_dict.keys())
new_keys = convert_old_keys_to_new_keys(all_keys)
dim = config.backbone_config.hidden_size
state_dict = {}
for key in all_keys:
new_key = new_keys[key]
value = original_state_dict[key]
if re.search("associate_heads", new_key) or re.search("backbone.cls_token", new_key):
# The associate_heads weights belong to an auxiliary head, so they are not needed at inference time.
# backbone.cls_token is only used in an optional forward path for dynamic input size changes, see https://github.com/ViTAE-Transformer/ViTPose/issues/34 for details.
pass
elif re.search("qkv", new_key):
state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim]
state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2]
state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:]
elif re.search("head", new_key) and not config.use_simple_decoder:
# Pattern for deconvolution layers
deconv_pattern = r"deconv_layers\.(0|3)\.weight"
new_key = re.sub(deconv_pattern, lambda m: f"deconv{int(m.group(1))//3 + 1}.weight", new_key)
# Pattern for batch normalization layers
bn_patterns = [
(r"deconv_layers\.(\d+)\.weight", r"batchnorm\1.weight"),
(r"deconv_layers\.(\d+)\.bias", r"batchnorm\1.bias"),
(r"deconv_layers\.(\d+)\.running_mean", r"batchnorm\1.running_mean"),
(r"deconv_layers\.(\d+)\.running_var", r"batchnorm\1.running_var"),
(r"deconv_layers\.(\d+)\.num_batches_tracked", r"batchnorm\1.num_batches_tracked"),
]
for pattern, replacement in bn_patterns:
if re.search(pattern, new_key):
# Convert the layer number to the correct batch norm index
layer_num = int(re.search(pattern, key).group(1))
bn_num = layer_num // 3 + 1
new_key = re.sub(pattern, replacement.replace(r"\1", str(bn_num)), new_key)
state_dict[new_key] = value
else:
state_dict[new_key] = value
print("Loading the checkpoint in a Vitpose model.")
model = VitPoseForPoseEstimation(config)
model.eval()
model.load_state_dict(state_dict)
print("Checkpoint loaded successfully.")
# create image processor
image_processor = VitPoseImageProcessor()
# verify image processor
image = prepare_img()
boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
pixel_values = image_processor(images=image, boxes=boxes, return_tensors="pt").pixel_values
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="vitpose_batch_data.pt", repo_type="dataset")
original_pixel_values = torch.load(filepath, map_location="cpu")["img"]
assert torch.allclose(pixel_values, original_pixel_values, atol=1e-1)
dataset_index = torch.tensor([0])
with torch.no_grad():
# first forward pass
outputs = model(pixel_values, dataset_index=dataset_index)
output_heatmap = outputs.heatmaps
# second forward pass (flipped)
# this is done since the model uses `flip_test=True` in its test config
pixel_values_flipped = torch.flip(pixel_values, [3])
outputs_flipped = model(
pixel_values_flipped,
dataset_index=dataset_index,
flip_pairs=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]),
)
output_flipped_heatmap = outputs_flipped.heatmaps
outputs.heatmaps = (output_heatmap + output_flipped_heatmap) * 0.5
# Verify pose_results
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
if check_logits:
if model_name == "vitpose-base-simple":
assert torch.allclose(
pose_results[1]["keypoints"][0],
torch.tensor([3.98180511e02, 1.81808380e02]),
atol=5e-2,
)
assert torch.allclose(
pose_results[1]["scores"][0],
torch.tensor([8.66642594e-01]),
atol=5e-2,
)
elif model_name == "vitpose-base":
assert torch.allclose(
pose_results[1]["keypoints"][0],
torch.tensor([3.9807913e02, 1.8182812e02]),
atol=5e-2,
)
assert torch.allclose(
pose_results[1]["scores"][0],
torch.tensor([8.8235235e-01]),
atol=5e-2,
)
elif model_name == "vitpose-base-coco-aic-mpii":
assert torch.allclose(
pose_results[1]["keypoints"][0],
torch.tensor([3.98305542e02, 1.81741592e02]),
atol=5e-2,
)
assert torch.allclose(
pose_results[1]["scores"][0],
torch.tensor([8.69966745e-01]),
atol=5e-2,
)
elif model_name == "vitpose-plus-base":
assert torch.allclose(
pose_results[1]["keypoints"][0],
torch.tensor([3.98201294e02, 1.81728302e02]),
atol=5e-2,
)
assert torch.allclose(
pose_results[1]["scores"][0],
torch.tensor([8.75046968e-01]),
atol=5e-2,
)
else:
raise ValueError("Model not supported")
print("Conversion successfully done.")
# save the model to a local directory
model.save_pretrained(model_path)
image_processor.save_pretrained(model_path)
if push_to_hub:
print(f"Pushing model and image processor for {model_name} to hub")
model.push_to_hub(f"danelcsb/{model_name}")
image_processor.push_to_hub(f"danelcsb/{model_name}")
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="vitpose-base-simple",
choices=MODEL_TO_FILE_NAME_MAPPING.keys(),
type=str,
help="Name of the VitPose model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
parser.add_argument(
"--check_logits",
default=True,
type=bool,
help="Whether to verify the logits of the converted model against reference values. Disable this when converting a custom model.",
)
args = parser.parse_args()
write_model(
model_path=args.pytorch_dump_folder_path,
model_name=args.model_name,
push_to_hub=args.push_to_hub,
check_logits=args.check_logits,
)
if __name__ == "__main__":
main()
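# Example invocation (the output directory below is a hypothetical local path), using the arguments defined above:
#   python src/transformers/models/vitpose/convert_vitpose_to_hf.py \
#       --model_name vitpose-base-simple \
#       --pytorch_dump_folder_path ./converted/vitpose-base-simple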

View File

@ -0,0 +1,684 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for VitPose."""
import itertools
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import to_channel_dimension_format
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from ...utils import TensorType, is_scipy_available, is_torch_available, is_vision_available, logging
if is_torch_available():
import torch
if is_vision_available():
import PIL
if is_scipy_available():
from scipy.linalg import inv
from scipy.ndimage import affine_transform, gaussian_filter
if TYPE_CHECKING:
from .modeling_vitpose import VitPoseEstimatorOutput
logger = logging.get_logger(__name__)
# inspired by https://github.com/ViTAE-Transformer/ViTPose/blob/d5216452796c90c6bc29f5c5ec0bdba94366768a/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py#L132
def box_to_center_and_scale(
box: Union[Tuple, List, np.ndarray],
image_width: int,
image_height: int,
normalize_factor: float = 200.0,
padding_factor: float = 1.25,
):
"""
Encodes a bounding box in COCO format into (center, scale).
Args:
box (`Tuple`, `List`, or `np.ndarray`):
Bounding box in COCO format (top_left_x, top_left_y, width, height).
image_width (`int`):
Image width.
image_height (`int`):
Image height.
normalize_factor (`float`):
Width and height scale factor.
padding_factor (`float`):
Bounding box padding factor.
Returns:
tuple: A tuple containing center and scale.
- `np.ndarray` [float32](2,): Center of the bbox (x, y).
- `np.ndarray` [float32](2,): Scale of the bbox width & height.
"""
top_left_x, top_left_y, width, height = box[:4]
aspect_ratio = image_width / image_height
center = np.array([top_left_x + width * 0.5, top_left_y + height * 0.5], dtype=np.float32)
if width > aspect_ratio * height:
height = width * 1.0 / aspect_ratio
elif width < aspect_ratio * height:
width = height * aspect_ratio
scale = np.array([width / normalize_factor, height / normalize_factor], dtype=np.float32)
scale = scale * padding_factor
return center, scale
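# Hypothetical worked example, using the test box from the conversion script and the default crop size
# (width=192, height=256) that `preprocess` passes in:
#   box = [412.8, 157.61, 53.05, 138.01]  ->  center ≈ [439.3, 226.6]
#   the box is first widened to the 192/256 aspect ratio, giving scale ≈ [0.65, 0.86]
#   (expressed in units of normalize_factor=200 px and including the 1.25 padding factor)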
def coco_to_pascal_voc(bboxes: np.ndarray) -> np.ndarray:
"""
Converts bounding boxes from the COCO format to the Pascal VOC format.
In other words, converts from (top_left_x, top_left_y, width, height) format
to (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
Args:
bboxes (`np.ndarray` of shape `(batch_size, 4)`):
Bounding boxes in COCO format.
Returns:
`np.ndarray` of shape `(batch_size, 4)` in Pascal VOC format.
"""
bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0] - 1
bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1] - 1
return bboxes
def get_keypoint_predictions(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Get keypoint predictions from score maps.
Args:
heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
Model predicted heatmaps.
Returns:
tuple: A tuple containing aggregated results.
- coords (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
Predicted keypoint location.
- scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
Scores (confidence) of the keypoints.
"""
if not isinstance(heatmaps, np.ndarray):
raise ValueError("Heatmaps should be np.ndarray")
if heatmaps.ndim != 4:
raise ValueError("Heatmaps should be 4-dimensional")
batch_size, num_keypoints, _, width = heatmaps.shape
heatmaps_reshaped = heatmaps.reshape((batch_size, num_keypoints, -1))
idx = np.argmax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))
scores = np.amax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = preds[:, :, 0] % width
preds[:, :, 1] = preds[:, :, 1] // width
preds = np.where(np.tile(scores, (1, 1, 2)) > 0.0, preds, -1)
return preds, scores
def post_dark_unbiased_data_processing(coords: np.ndarray, batch_heatmaps: np.ndarray, kernel: int = 3) -> np.ndarray:
"""DARK post-pocessing. Implemented by unbiased_data_processing.
Paper references:
- Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
- Zhang et al. Distribution-Aware Coordinate Representation for Human Pose Estimation (CVPR 2020).
Args:
coords (`np.ndarray` of shape `(num_persons, num_keypoints, 2)`):
Initial coordinates of human pose.
batch_heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
Batched heatmaps as predicted by the model.
A batch_size of 1 is used for the bottom up paradigm where all persons share the same heatmap.
A batch_size of `num_persons` is used for the top down paradigm where each person has its own heatmaps.
kernel (`int`, *optional*, defaults to 3):
Gaussian kernel size (K) for modulation.
Returns:
`np.ndarray` of shape `(num_persons, num_keypoints, 2)` ):
Refined coordinates.
"""
batch_size, num_keypoints, height, width = batch_heatmaps.shape
num_coords = coords.shape[0]
if not (batch_size == 1 or batch_size == num_coords):
raise ValueError("The batch size of heatmaps should be 1 or equal to the batch size of coordinates.")
radius = int((kernel - 1) // 2)
batch_heatmaps = np.array(
[
[gaussian_filter(heatmap, sigma=0.8, radius=(radius, radius), axes=(0, 1)) for heatmap in heatmaps]
for heatmaps in batch_heatmaps
]
)
batch_heatmaps = np.clip(batch_heatmaps, 0.001, 50)
batch_heatmaps = np.log(batch_heatmaps)
batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten()
# calculate indices for coordinates
index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (width + 2)
index += (width + 2) * (height + 2) * np.arange(0, batch_size * num_keypoints).reshape(-1, num_keypoints)
index = index.astype(int).reshape(-1, 1)
i_ = batch_heatmaps_pad[index]
ix1 = batch_heatmaps_pad[index + 1]
iy1 = batch_heatmaps_pad[index + width + 2]
ix1y1 = batch_heatmaps_pad[index + width + 3]
ix1_y1_ = batch_heatmaps_pad[index - width - 3]
ix1_ = batch_heatmaps_pad[index - 1]
iy1_ = batch_heatmaps_pad[index - 2 - width]
# calculate refined coordinates using Newton's method
dx = 0.5 * (ix1 - ix1_)
dy = 0.5 * (iy1 - iy1_)
derivative = np.concatenate([dx, dy], axis=1)
derivative = derivative.reshape(num_coords, num_keypoints, 2, 1)
dxx = ix1 - 2 * i_ + ix1_
dyy = iy1 - 2 * i_ + iy1_
dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
hessian = hessian.reshape(num_coords, num_keypoints, 2, 2)
hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze()
return coords
def transform_preds(coords: np.ndarray, center: np.ndarray, scale: np.ndarray, output_size: np.ndarray) -> np.ndarray:
"""Get final keypoint predictions from heatmaps and apply scaling and
translation to map them back to the image.
Note:
num_keypoints: K
Args:
coords (`np.ndarray` of shape `(num_keypoints, ndims)`):
* If ndims=2, coords are the predicted keypoint locations.
* If ndims=4, coords are composed of (x, y, scores, tags).
* If ndims=5, coords are composed of (x, y, scores, tags, flipped_tags).
center (`np.ndarray` of shape `(2,)`):
Center of the bounding box (x, y).
scale (`np.ndarray` of shape `(2,)`):
Scale of the bounding box wrt original image of width and height.
output_size (`np.ndarray` of shape `(2,)`):
Size of the destination heatmaps in (height, width) format.
Returns:
np.ndarray: Predicted coordinates in the images.
"""
if coords.shape[1] not in (2, 4, 5):
raise ValueError("Coordinates need to have either 2, 4 or 5 dimensions.")
if len(center) != 2:
raise ValueError("Center needs to have 2 elements, one for x and one for y.")
if len(scale) != 2:
raise ValueError("Scale needs to consist of a width and height")
if len(output_size) != 2:
raise ValueError("Output size needs to consist of a height and width")
# Recover the scale which is normalized by a factor of 200.
scale = scale * 200.0
# We use unbiased data processing
scale_y = scale[1] / (output_size[0] - 1.0)
scale_x = scale[0] / (output_size[1] - 1.0)
target_coords = np.ones_like(coords)
target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
return target_coords
def get_warp_matrix(theta: float, size_input: np.ndarray, size_dst: np.ndarray, size_target: np.ndarray):
"""
Calculate the transformation matrix under the constraint of unbiased. Paper ref: Huang et al. The Devil is in the
Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
Source: https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Args:
theta (`float`):
Rotation angle in degrees.
size_input (`np.ndarray`):
Size of input image [width, height].
size_dst (`np.ndarray`):
Size of output image [width, height].
size_target (`np.ndarray`):
Size of ROI in input plane [w, h].
Returns:
`np.ndarray`: A matrix for transformation.
"""
theta = np.deg2rad(theta)
matrix = np.zeros((2, 3), dtype=np.float32)
scale_x = size_dst[0] / size_target[0]
scale_y = size_dst[1] / size_target[1]
matrix[0, 0] = math.cos(theta) * scale_x
matrix[0, 1] = -math.sin(theta) * scale_x
matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * math.cos(theta) + 0.5 * size_input[1] * math.sin(theta) + 0.5 * size_target[0]
)
matrix[1, 0] = math.sin(theta) * scale_y
matrix[1, 1] = math.cos(theta) * scale_y
matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * math.sin(theta) - 0.5 * size_input[1] * math.cos(theta) + 0.5 * size_target[1]
)
return matrix
def scipy_warp_affine(src, M, size):
"""
This function reimplements cv2.warpAffine using scipy.ndimage.affine_transform. See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html and
https://docs.opencv.org/4.x/d4/d61/tutorial_warp_affine.html for more details.
Note: the original implementation of cv2.warpAffine uses cv2.INTER_LINEAR.
"""
channels = [src[..., i] for i in range(src.shape[-1])]
# Convert to a 3x3 matrix used by SciPy
M_scipy = np.vstack([M, [0, 0, 1]])
# If you have a matrix for the push transformation, use its inverse (numpy.linalg.inv) in this function.
M_inv = inv(M_scipy)
M_inv[0, 0], M_inv[0, 1], M_inv[1, 0], M_inv[1, 1], M_inv[0, 2], M_inv[1, 2] = (
M_inv[1, 1],
M_inv[1, 0],
M_inv[0, 1],
M_inv[0, 0],
M_inv[1, 2],
M_inv[0, 2],
)
new_src = [affine_transform(channel, M_inv, output_shape=size, order=1) for channel in channels]
new_src = np.stack(new_src, axis=-1)
return new_src
class VitPoseImageProcessor(BaseImageProcessor):
r"""
Constructs a VitPose image processor.
Args:
do_affine_transform (`bool`, *optional*, defaults to `True`):
Whether to apply an affine transformation to the input images.
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 192}`):
Resolution of the image after `affine_transform` is applied. Only has an effect if `do_affine_transform` is set to `True`. Can
be overriden by `size` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`, *optional*):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`, *optional*):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_affine_transform: bool = True,
size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.do_affine_transform = do_affine_transform
self.size = size if size is not None else {"height": 256, "width": 192}
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.normalize_factor = 200.0
def affine_transform(
self,
image: np.array,
center: Tuple[float],
scale: Tuple[float],
rotation: float,
size: Dict[str, int],
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
"""
Apply an affine transformation to an image.
Args:
image (`np.array`):
Image to transform.
center (`Tuple[float]`):
Center of the bounding box (x, y).
scale (`Tuple[float]`):
Scale of the bounding box with respect to height/width.
rotation (`float`):
Rotation angle in degrees.
size (`Dict[str, int]`):
Size of the destination image.
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format of the output image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image.
"""
data_format = input_data_format if data_format is None else data_format
size = (size["width"], size["height"])
# one uses a pixel standard deviation of 200 pixels
transformation = get_warp_matrix(rotation, center * 2.0, np.array(size) - 1.0, scale * 200.0)
# input image requires channels last format
image = (
image
if input_data_format == ChannelDimension.LAST
else to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
)
image = scipy_warp_affine(src=image, M=transformation, size=(size[1], size[0]))
image = to_channel_dimension_format(image, data_format, ChannelDimension.LAST)
return image
def preprocess(
self,
images: ImageInput,
boxes: Union[List[List[float]], np.ndarray],
do_affine_transform: bool = None,
size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
boxes (`List[List[List[float]]]` or `np.ndarray`):
List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
box coordinates in COCO format (top_left_x, top_left_y, width, height).
do_affine_transform (`bool`, *optional*, defaults to `self.do_affine_transform`):
Whether to apply an affine transformation to the input images.
size (`Dict[str, int]` *optional*, defaults to `self.size`):
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
resizing.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
do_affine_transform = do_affine_transform if do_affine_transform is not None else self.do_affine_transform
size = size if size is not None else self.size
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if isinstance(boxes, list) and len(images) != len(boxes):
raise ValueError(f"Batch of images and boxes mismatch: {len(images)} != {len(boxes)}")
elif isinstance(boxes, np.ndarray) and len(images) != boxes.shape[0]:
raise ValueError(f"Batch of images and boxes mismatch: {len(images)} != {boxes.shape[0]}")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
# transformations (affine transformation + rescaling + normalization)
if do_affine_transform:
new_images = []
for image, image_boxes in zip(images, boxes):
for box in image_boxes:
center, scale = box_to_center_and_scale(
box,
image_width=size["width"],
image_height=size["height"],
normalize_factor=self.normalize_factor,
)
transformed_image = self.affine_transform(
image, center, scale, rotation=0, size=size, input_data_format=input_data_format
)
new_images.append(transformed_image)
images = new_images
# For batch processing, the number of boxes must be consistent across all images in the batch.
# When using a list input, the number of boxes can vary dynamically per image.
# The image processor creates pixel_values of shape (batch_size*num_persons, num_channels, height, width)
all_images = []
for image in images:
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
all_images.append(image)
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in all_images
]
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def keypoints_from_heatmaps(
self,
heatmaps: np.ndarray,
center: np.ndarray,
scale: np.ndarray,
kernel: int = 11,
):
"""
Get final keypoint predictions from heatmaps and transform them back to
the image.
Args:
heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
Model predicted heatmaps.
center (`np.ndarray` of shape `(batch_size, 2)`):
Center of the bounding box (x, y).
scale (`np.ndarray` of shape `(batch_size, 2)`):
Scale of the bounding box wrt original images of width and height.
kernel (`int`, *optional*, defaults to 11):
Gaussian kernel size (K) for modulation, which should match the heatmap gaussian sigma used during training.
K=17 for sigma=3 and K=11 for sigma=2.
Returns:
tuple: A tuple containing keypoint predictions and scores.
- preds (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
Predicted keypoint location in images.
- scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
Scores (confidence) of the keypoints.
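Example (illustrative sketch with random heatmaps; a default-constructed `VitPoseImageProcessor` and
arbitrary centers/scales are assumed):
```python
>>> import numpy as np
>>> from transformers import VitPoseImageProcessor

>>> image_processor = VitPoseImageProcessor()
>>> heatmaps = np.random.rand(2, 17, 64, 48).astype(np.float32)
>>> centers = np.array([[320.0, 240.0], [100.0, 80.0]], dtype=np.float32)
>>> scales = np.array([[1.0, 1.5], [0.8, 1.2]], dtype=np.float32)

>>> preds, scores = image_processor.keypoints_from_heatmaps(heatmaps, centers, scales)
>>> # preds: (batch_size, num_keypoints, 2) image coordinates, scores: (batch_size, num_keypoints, 1)
```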
"""
batch_size, _, height, width = heatmaps.shape
coords, scores = get_keypoint_predictions(heatmaps)
preds = post_dark_unbiased_data_processing(coords, heatmaps, kernel=kernel)
# Transform back to the image
for i in range(batch_size):
preds[i] = transform_preds(preds[i], center=center[i], scale=scale[i], output_size=[height, width])
return preds, scores
def post_process_pose_estimation(
self,
outputs: "VitPoseEstimatorOutput",
boxes: Union[List[List[List[float]]], np.ndarray],
kernel_size: int = 11,
threshold: Optional[float] = None,
target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
):
"""
Transform the heatmaps into keypoint predictions and transform them back to the image.
Args:
outputs (`VitPoseEstimatorOutput`):
VitPoseForPoseEstimation model outputs.
boxes (`List[List[List[float]]]` or `np.ndarray`):
List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
box coordinates in COCO format (top_left_x, top_left_y, width, height).
kernel_size (`int`, *optional*, defaults to 11):
Gaussian kernel size (K) for modulation.
threshold (`float`, *optional*):
Score threshold used to filter out low-confidence keypoint predictions.
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
`(height, width)` of each image in the batch. If set, the (normalized) bounding boxes are rescaled to
these sizes; if unset, the boxes are assumed to already be in absolute (pixel) coordinates.
Returns:
`List[List[Dict]]`: A list with one entry per image in the batch, where each entry is a list of dictionaries,
one per detected person, containing the keypoints, scores, labels and bounding box predicted by the model.
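Example (illustrative end-to-end sketch; reuses the checkpoint, image and boxes from the
[`VitPoseForPoseEstimation.forward`] example, and the exact keypoint values are not guaranteed):
```python
>>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
>>> import torch
>>> from PIL import Image
>>> import requests

>>> image_processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
>>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]

>>> inputs = image_processor(image, boxes=boxes, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)
>>> # one list per image, one dict per person with "keypoints", "scores", "labels" and "bbox"
>>> first_person = results[0][0]
>>> keypoints, keypoint_scores = first_person["keypoints"], first_person["scores"]
```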
"""
# First compute centers and scales for each bounding box
batch_size, num_keypoints, _, _ = outputs.heatmaps.shape
if target_sizes is not None:
if batch_size != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
centers = np.zeros((batch_size, 2), dtype=np.float32)
scales = np.zeros((batch_size, 2), dtype=np.float32)
flattened_boxes = list(itertools.chain(*boxes))
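# outputs.heatmaps stacks all persons of all images along the batch dimension, in the same order as the
# flattened boxes, so the i-th heatmap corresponds to the i-th flattened box.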
for i in range(batch_size):
if target_sizes is not None:
image_width, image_height = target_sizes[i][0], target_sizes[i][1]
scale_factor = np.array([image_width, image_height, image_width, image_height])
flattened_boxes[i] = flattened_boxes[i] * scale_factor
width, height = self.size["width"], self.size["height"]
center, scale = box_to_center_and_scale(flattened_boxes[i], image_width=width, image_height=height)
centers[i, :] = center
scales[i, :] = scale
preds, scores = self.keypoints_from_heatmaps(
outputs.heatmaps.cpu().numpy(), centers, scales, kernel=kernel_size
)
all_boxes = np.zeros((batch_size, 4), dtype=np.float32)
all_boxes[:, 0:2] = centers[:, 0:2]
all_boxes[:, 2:4] = scales[:, 0:2]
poses = torch.tensor(preds)
scores = torch.tensor(scores)
labels = torch.arange(0, num_keypoints)
bboxes_xyxy = torch.tensor(coco_to_pascal_voc(all_boxes))
results: List[List[Dict[str, torch.Tensor]]] = []
pose_bbox_pairs = zip(poses, scores, bboxes_xyxy)
for image_bboxes in boxes:
image_results: List[Dict[str, torch.Tensor]] = []
for _ in image_bboxes:
# Unpack the next pose and bbox_xyxy from the iterator
pose, score, bbox_xyxy = next(pose_bbox_pairs)
score = score.squeeze()
keypoints_labels = labels
if threshold is not None:
keep = score > threshold
pose = pose[keep]
score = score[keep]
keypoints_labels = keypoints_labels[keep]
pose_result = {"keypoints": pose, "scores": score, "labels": keypoints_labels, "bbox": bbox_xyxy}
image_results.append(pose_result)
results.append(image_results)
return results
__all__ = ["VitPoseImageProcessor"]

View File

@ -0,0 +1,340 @@
# coding=utf-8
# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VitPose model."""
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import load_backbone
from .configuration_vitpose import VitPoseConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "VitPoseConfig"
@dataclass
class VitPoseEstimatorOutput(ModelOutput):
"""
Class for outputs of pose estimation models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
Heatmaps as predicted by the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
(also called feature maps) of the model at the output of each stage.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
heatmaps: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class VitPosePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = VitPoseConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
VITPOSE_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`VitPoseConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
VITPOSE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`VitPoseImageProcessor`]. See
[`VitPoseImageProcessor.__call__`] for details.
dataset_index (`torch.Tensor` of shape `(batch_size,)`):
Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
This corresponds to the dataset index used during training. When training on a single dataset, index 0 refers to that dataset; when training on multiple datasets, index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose).
flip_pairs (`torch.Tensor`, *optional*):
Pairs of keypoint indices to mirror (for example, left ear -- right ear).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
"""Flip the flipped heatmaps back to the original form.
Args:
output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
The output heatmaps obtained from the flipped images.
flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
Pairs of keypoints which are mirrored (for example, left ear -- right ear).
target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
Target type to use. Can be gaussian-heatmap or combined-target.
gaussian-heatmap: Classification target with gaussian distribution.
combined-target: The combination of classification target (response map) and regression target (offset map).
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
Returns:
torch.Tensor: Heatmaps flipped back to match the original (non-flipped) image.
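Example (illustrative, with random heatmaps; the flip pairs below are hypothetical keypoint indices):
```python
>>> import torch
>>> output_flipped = torch.randn(1, 17, 64, 48)
>>> flip_pairs = torch.tensor([[1, 2], [3, 4]])  # e.g. left/right eyes and left/right ears
>>> restored = flip_back(output_flipped, flip_pairs)
>>> restored.shape
torch.Size([1, 17, 64, 48])
```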
"""
if target_type not in ["gaussian-heatmap", "combined-target"]:
raise ValueError("target_type should be gaussian-heatmap or combined-target")
if output_flipped.ndim != 4:
raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")
batch_size, num_keypoints, height, width = output_flipped.shape
channels = 1
if target_type == "combined-target":
channels = 3
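# Combined targets interleave (response map, x-offset, y-offset) per keypoint, so flipping the image
# horizontally requires negating the x-offset channels (every third channel starting at index 1).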
output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
output_flipped_back = output_flipped.clone()
# Swap left-right parts
for left, right in flip_pairs.tolist():
output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))
# Flip horizontally
output_flipped_back = output_flipped_back.flip(-1)
return output_flipped_back
class VitPoseSimpleDecoder(nn.Module):
"""
Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
feature maps into heatmaps.
"""
def __init__(self, config) -> None:
super().__init__()
self.activation = nn.ReLU()
self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
self.conv = nn.Conv2d(
config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1
)
def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor:
# Transform input: ReLU + upsample
hidden_state = self.activation(hidden_state)
hidden_state = self.upsampling(hidden_state)
heatmaps = self.conv(hidden_state)
if flip_pairs is not None:
heatmaps = flip_back(heatmaps, flip_pairs)
return heatmaps
class VitPoseClassicDecoder(nn.Module):
"""
Classic decoding head consisting of 2 deconvolutional blocks, followed by a 1x1 convolution layer,
turning the feature maps into heatmaps.
"""
def __init__(self, config: VitPoseConfig):
super().__init__()
self.deconv1 = nn.ConvTranspose2d(
config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
)
self.batchnorm1 = nn.BatchNorm2d(256)
self.relu1 = nn.ReLU()
self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
self.batchnorm2 = nn.BatchNorm2d(256)
self.relu2 = nn.ReLU()
self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None):
hidden_state = self.deconv1(hidden_state)
hidden_state = self.batchnorm1(hidden_state)
hidden_state = self.relu1(hidden_state)
hidden_state = self.deconv2(hidden_state)
hidden_state = self.batchnorm2(hidden_state)
hidden_state = self.relu2(hidden_state)
heatmaps = self.conv(hidden_state)
if flip_pairs is not None:
heatmaps = flip_back(heatmaps, flip_pairs)
return heatmaps
@add_start_docstrings(
"The VitPose model with a pose estimation head on top.",
VITPOSE_START_DOCSTRING,
)
class VitPoseForPoseEstimation(VitPosePreTrainedModel):
def __init__(self, config: VitPoseConfig) -> None:
super().__init__(config)
self.backbone = load_backbone(config)
# add backbone attributes
if not hasattr(self.backbone.config, "hidden_size"):
raise ValueError("The backbone should have a hidden_size attribute")
if not hasattr(self.backbone.config, "image_size"):
raise ValueError("The backbone should have an image_size attribute")
if not hasattr(self.backbone.config, "patch_size"):
raise ValueError("The backbone should have a patch_size attribute")
self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VITPOSE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VitPoseEstimatorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: torch.Tensor,
dataset_index: Optional[torch.Tensor] = None,
flip_pairs: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, VitPoseEstimatorOutput]:
"""
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
>>> import torch
>>> from PIL import Image
>>> import requests
>>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
>>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
>>> inputs = processor(image, boxes=boxes, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> heatmaps = outputs.heatmaps
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
loss = None
if labels is not None:
raise NotImplementedError("Training is not yet supported")
outputs = self.backbone.forward_with_filtered_kwargs(
pixel_values,
dataset_index=dataset_index,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
# Turn the output hidden states into a tensor of shape (batch_size, num_channels, height, width)
sequence_output = outputs.feature_maps[-1] if return_dict else outputs[0][-1]
batch_size = sequence_output.shape[0]
patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0]
patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1]
sequence_output = (
sequence_output.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width).contiguous()
)
heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)
if not return_dict:
if output_hidden_states:
output = (heatmaps,) + outputs[1:]
else:
output = (heatmaps,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return VitPoseEstimatorOutput(
loss=loss,
heatmaps=heatmaps,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]

View File

@ -0,0 +1,54 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_vitpose_backbone": ["VitPoseBackboneConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vitpose_backbone"] = [
"VitPoseBackbonePreTrainedModel",
"VitPoseBackbone",
]
if TYPE_CHECKING:
from .configuration_vitpose_backbone import VitPoseBackboneConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vitpose_backbone import (
VitPoseBackbone,
VitPoseBackbonePreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -0,0 +1,136 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VitPose backbone configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__)
class VitPoseBackboneConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VitPoseBackbone`]. It is used to instantiate a
VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VitPose
[usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
image_size (`List[int]`, *optional*, defaults to `[256, 192]`):
The size (resolution) of each image.
patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
mlp_ratio (`int`, *optional*, defaults to 4):
The ratio of the hidden size in the feedforward network to the hidden size in the attention layers.
num_experts (`int`, *optional*, defaults to 1):
The number of experts in the MoE layer.
part_features (`int`, *optional*, defaults to 256):
The number of part features to output. Only used in case `num_experts` is greater than 1.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example:
```python
>>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
>>> # Initializing a VitPose configuration
>>> configuration = VitPoseBackboneConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = VitPoseBackbone(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "vitpose_backbone"
def __init__(
self,
image_size=[256, 192],
patch_size=[16, 16],
num_channels=3,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
mlp_ratio=4,
num_experts=1,
part_features=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-12,
qkv_bias=True,
out_features=None,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.mlp_ratio = mlp_ratio
self.num_experts = num_experts
self.part_features = part_features
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)

View File

@ -0,0 +1,542 @@
# coding=utf-8
# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VitPose backbone model.
This code is the same as the original Vision Transformer (ViT) with 2 modifications:
- use of padding=2 in the patch embedding layer
- addition of a mixture-of-experts MLP layer
"""
import collections.abc
import math
from typing import Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput, BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_vitpose_backbone import VitPoseBackboneConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "VitPoseBackboneConfig"
class VitPoseBackbonePatchEmbeddings(nn.Module):
"""Image to Patch Embedding."""
def __init__(self, config):
super().__init__()
image_size = config.image_size
patch_size = config.patch_size
num_channels = config.num_channels
embed_dim = config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=2)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
height, width = pixel_values.shape[-2:]
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values)
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings
class VitPoseBackboneEmbeddings(nn.Module):
"""
Construct the position and patch embeddings.
"""
def __init__(self, config: VitPoseBackboneConfig) -> None:
super().__init__()
self.patch_embeddings = VitPoseBackbonePatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
embeddings = self.patch_embeddings(pixel_values)
# add positional encoding to each token
embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
embeddings = self.dropout(embeddings)
return embeddings
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
class VitPoseBackboneSelfAttention(nn.Module):
def __init__(self, config: VitPoseBackboneConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VitPoseBackbone
class VitPoseBackboneSelfOutput(nn.Module):
"""
The residual connection is defined in VitPoseBackboneLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: VitPoseBackboneConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VitPoseBackbone
class VitPoseBackboneAttention(nn.Module):
def __init__(self, config: VitPoseBackboneConfig) -> None:
super().__init__()
self.attention = VitPoseBackboneSelfAttention(config)
self.output = VitPoseBackboneSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class VitPoseBackboneMoeMLP(nn.Module):
def __init__(self, config: VitPoseBackboneConfig):
super().__init__()
in_features = out_features = config.hidden_size
hidden_features = int(config.hidden_size * config.mlp_ratio)
num_experts = config.num_experts
part_features = config.part_features
self.part_features = part_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = ACT2FN[config.hidden_act]
self.fc2 = nn.Linear(hidden_features, out_features - part_features)
self.drop = nn.Dropout(config.hidden_dropout_prob)
self.num_experts = num_experts
experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
self.experts = nn.ModuleList(experts)
def forward(self, hidden_state: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :])
hidden_state = self.fc1(hidden_state)
hidden_state = self.act(hidden_state)
shared_hidden_state = self.fc2(hidden_state)
indices = indices.view(-1, 1, 1)
# to support ddp training
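# Each sample's expert slice is produced by the expert matching its dataset index: `selected_index` is a
# boolean mask of shape (batch_size, 1, 1) that zeroes out the contributions of all other experts, while
# every expert still runs on the full batch (keeping gradients synchronized under DDP).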
for i in range(self.num_experts):
selected_index = indices == i
current_hidden_state = self.experts[i](hidden_state) * selected_index
expert_hidden_state = expert_hidden_state + current_hidden_state
hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1)
return hidden_state
class VitPoseBackboneMLP(nn.Module):
def __init__(self, config: VitPoseBackboneConfig) -> None:
super().__init__()
in_features = out_features = config.hidden_size
hidden_features = int(config.hidden_size * config.mlp_ratio)
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
self.activation = ACT2FN[config.hidden_act]
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.fc1(hidden_state)
hidden_state = self.activation(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
class VitPoseBackboneLayer(nn.Module):
def __init__(self, config: VitPoseBackboneConfig) -> None:
super().__init__()
self.num_experts = config.num_experts
self.attention = VitPoseBackboneAttention(config)
self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
dataset_index: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
# Validate dataset_index when using multiple experts
if self.num_experts > 1 and dataset_index is None:
raise ValueError(
"dataset_index must be provided when using multiple experts "
f"(num_experts={self.num_experts}). Please provide dataset_index "
"to the forward pass."
)
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in VitPoseBackbone, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# first residual connection
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
if self.num_experts == 1:
layer_output = self.mlp(layer_output)
else:
layer_output = self.mlp(layer_output, indices=dataset_index)
# second residual connection
layer_output = layer_output + hidden_states
outputs = (layer_output,) + outputs
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VitPoseBackbone
class VitPoseBackboneEncoder(nn.Module):
def __init__(self, config: VitPoseBackboneConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([VitPoseBackboneLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
dataset_index: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
dataset_index,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class VitPoseBackbonePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = VitPoseBackboneConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, VitPoseBackboneEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)
VITPOSE_BACKBONE_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`VitPoseBackboneConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
VITPOSE_BACKBONE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values.
dataset_index (`torch.Tensor` of shape `(batch_size,)`):
Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
This corresponds to the dataset index used during training, e.g. index 0 refers to COCO.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The VitPose backbone useful for downstream tasks.",
VITPOSE_BACKBONE_START_DOCSTRING,
)
class VitPoseBackbone(VitPoseBackbonePreTrainedModel, BackboneMixin):
def __init__(self, config: VitPoseBackboneConfig):
super().__init__(config)
super()._init_backbone(config)
self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
self.embeddings = VitPoseBackboneEmbeddings(config)
self.encoder = VitPoseBackboneEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VITPOSE_BACKBONE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: torch.Tensor,
dataset_index: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Returns:
Examples:
```python
>>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
>>> import torch
>>> config = VitPoseBackboneConfig(out_indices=[-1])
>>> model = VitPoseBackbone(config)
>>> pixel_values = torch.randn(1, 3, 256, 192)
>>> dataset_index = torch.tensor([1])
>>> outputs = model(pixel_values, dataset_index)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(pixel_values)
outputs = self.encoder(
embedding_output,
dataset_index=dataset_index,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
hidden_states = outputs.hidden_states if return_dict else outputs[1]
feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
hidden_state = self.layernorm(hidden_state)
feature_maps += (hidden_state,)
if not return_dict:
if output_hidden_states:
output = (feature_maps,) + outputs[1:]
else:
output = (feature_maps,) + outputs[2:]
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)

View File

@ -9753,6 +9753,34 @@ class VitMattePreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class VitPoseForPoseEstimation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VitPosePreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VitPoseBackbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VitPoseBackbonePreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VitsModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -695,6 +695,13 @@ class VitMatteImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class VitPoseImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class VivitImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

View File

@ -0,0 +1,229 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import VitPoseImageProcessor
class VitPoseImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_affine_transform=True,
size=None,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"height": 20, "width": 20}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_affine_transform = do_affine_transform
self.size = size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
def prepare_image_processor_dict(self):
return {
"do_affine_transform": self.do_affine_transform,
"size": self.size,
"do_rescale": self.do_rescale,
"rescale_factor": self.rescale_factor,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class VitPoseImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = VitPoseImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = VitPoseImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_affine_transform"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size={"height": 42, "width": 42}
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
encoded_images = image_processing(image_inputs[0], boxes=boxes, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
# Test batched
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
encoded_images = image_processing(image_inputs, boxes=boxes, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size * 2, *expected_output_image_shape)
)
def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
encoded_images = image_processing(image_inputs[0], boxes=boxes, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
# Test batched
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
encoded_images = image_processing(image_inputs, boxes=boxes, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size * 2, *expected_output_image_shape)
)
def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
encoded_images = image_processing(image_inputs[0], boxes=boxes, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
# Test batched
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
encoded_images = image_processing(image_inputs, boxes=boxes, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size * 2, *expected_output_image_shape)
)
def test_call_numpy_4_channels(self):
# Test that can process images which have an arbitrary number of channels
# Initialize image_processing
image_processor = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
self.image_processor_tester.num_channels = 4
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
# Test not batched input
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
encoded_images = image_processor(
image_inputs[0],
boxes=boxes,
return_tensors="pt",
input_data_format="channels_last",
image_mean=0,
image_std=1,
).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (len(boxes[0]), *expected_output_image_shape))
# Test batched
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
encoded_images = image_processor(
image_inputs,
boxes=boxes,
return_tensors="pt",
input_data_format="channels_last",
image_mean=0,
image_std=1,
).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size * len(boxes[0]), *expected_output_image_shape),
)

View File

@ -0,0 +1,332 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch VitPose model."""
import inspect
import unittest
import requests
from transformers import VitPoseBackboneConfig, VitPoseConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import VitPoseForPoseEstimation
if is_vision_available():
from PIL import Image
from transformers import VitPoseImageProcessor
class VitPoseModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=[16 * 8, 12 * 8],
patch_size=[8, 8],
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=2,
scale_factor=4,
out_indices=[-1],
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scale_factor = scale_factor
self.out_indices = out_indices
self.scope = scope
# in VitPose, the seq length equals the number of patches
num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])
self.seq_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return VitPoseConfig(
backbone_config=self.get_backbone_config(),
)
def get_backbone_config(self):
return VitPoseBackboneConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
num_hidden_layers=self.num_hidden_layers,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
num_attention_heads=self.num_attention_heads,
hidden_act=self.hidden_act,
out_indices=self.out_indices,
)
def create_and_check_for_pose_estimation(self, config, pixel_values, labels):
model = VitPoseForPoseEstimation(config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
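        # the pose head upsamples the backbone's patch grid by `scale_factor`, so the expected heatmap size is (image_size // patch_size) * scale_factor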
expected_height = (self.image_size[0] // self.patch_size[0]) * self.scale_factor
expected_width = (self.image_size[1] // self.patch_size[1]) * self.scale_factor
self.parent.assertEqual(
result.heatmaps.shape, (self.batch_size, self.num_labels, expected_height, expected_width)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
pixel_values,
labels,
) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class VitPoseModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as VitPose does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (VitPoseForPoseEstimation,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = VitPoseModelTester(self)
self.config_tester = ConfigTester(self, config_class=VitPoseConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
@unittest.skip(reason="VitPose does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="VitPose does not support input and output embeddings")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="VitPose does not support input and output embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="VitPose does not support training yet")
def test_training(self):
pass
@unittest.skip(reason="VitPose does not support training yet")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="VitPose does not support training yet")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="VitPose does not support training yet")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_for_pose_estimation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model_name = "usyd-community/vitpose-base-simple"
model = VitPoseForPoseEstimation.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of people in a house
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000000139.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return image
@require_torch
@require_vision
class VitPoseModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return (
VitPoseImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
if is_vision_available()
else None
)
@slow
def test_inference_pose_estimation(self):
image_processor = self.default_image_processor
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
model.to(torch_device)
model.eval()
image = prepare_img()
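        # two person boxes for this image, given in COCO-style (x, y, width, height) format in absolute pixel coordinates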
boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
inputs = image_processor(images=image, boxes=boxes, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
heatmaps = outputs.heatmaps
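        # one 17-keypoint heatmap stack per box: (num_boxes, num_keypoints, height, width)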
assert heatmaps.shape == (2, 17, 64, 48)
expected_slice = torch.tensor(
[
[9.9330e-06, 9.9330e-06, 9.9330e-06],
[9.9330e-06, 9.9330e-06, 9.9330e-06],
[9.9330e-06, 9.9330e-06, 9.9330e-06],
]
).to(torch_device)
assert torch.allclose(heatmaps[0, 0, :3, :3], expected_slice, atol=1e-4)
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
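        # post-processing returns one dict per box (keypoints, scores, bbox), with keypoint coordinates mapped back to the original image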
expected_bbox = torch.tensor([391.9900, 190.0800, 391.1575, 189.3034])
expected_keypoints = torch.tensor(
[
[3.9813e02, 1.8184e02],
[3.9828e02, 1.7981e02],
[3.9596e02, 1.7948e02],
]
)
expected_scores = torch.tensor([8.7529e-01, 8.4315e-01, 9.2678e-01])
self.assertEqual(len(pose_results), 2)
self.assertTrue(torch.allclose(pose_results[1]["bbox"].cpu(), expected_bbox, atol=1e-4))
self.assertTrue(torch.allclose(pose_results[1]["keypoints"][:3].cpu(), expected_keypoints, atol=1e-2))
self.assertTrue(torch.allclose(pose_results[1]["scores"][:3].cpu(), expected_scores, atol=1e-4))
@slow
def test_batched_inference(self):
image_processor = self.default_image_processor
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
model.to(torch_device)
model.eval()
image = prepare_img()
boxes = [
[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]],
[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]],
]
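        # two copies of the same image, each with the same two boxes, so the processor produces four crops in total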
inputs = image_processor(images=[image, image], boxes=boxes, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
heatmaps = outputs.heatmaps
assert heatmaps.shape == (4, 17, 64, 48)
expected_slice = torch.tensor(
[
[9.9330e-06, 9.9330e-06, 9.9330e-06],
[9.9330e-06, 9.9330e-06, 9.9330e-06],
[9.9330e-06, 9.9330e-06, 9.9330e-06],
]
).to(torch_device)
assert torch.allclose(heatmaps[0, 0, :3, :3], expected_slice, atol=1e-4)
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)
expected_bbox = torch.tensor([391.9900, 190.0800, 391.1575, 189.3034])
expected_keypoints = torch.tensor(
[
[3.9813e02, 1.8184e02],
[3.9828e02, 1.7981e02],
[3.9596e02, 1.7948e02],
]
)
expected_scores = torch.tensor([8.7529e-01, 8.4315e-01, 9.2678e-01])
self.assertEqual(len(pose_results), 2)
self.assertEqual(len(pose_results[0]), 2)
self.assertTrue(torch.allclose(pose_results[0][1]["bbox"].cpu(), expected_bbox, atol=1e-4))
self.assertTrue(torch.allclose(pose_results[0][1]["keypoints"][:3].cpu(), expected_keypoints, atol=1e-2))
self.assertTrue(torch.allclose(pose_results[0][1]["scores"][:3].cpu(), expected_scores, atol=1e-4))


@@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch VitPose backbone model."""
import inspect
import unittest
from transformers import VitPoseBackboneConfig
from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
from transformers import VitPoseBackbone
class VitPoseBackboneModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=[16 * 8, 12 * 8],
patch_size=[8, 8],
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
# in VitPoseBackbone, the seq length equals the number of patches
num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])
self.seq_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return VitPoseBackboneConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
num_labels=self.num_labels,
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
pixel_values,
labels,
) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as VitPoseBackbone does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (VitPoseBackbone,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = VitPoseBackboneModelTester(self)
self.config_tester = ConfigTester(
self, config_class=VitPoseBackboneConfig, has_text_modality=False, hidden_size=37
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="VitPoseBackbone does not support feedforward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="VitPoseBackbone does not output a loss")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="VitPoseBackbone does not support training yet")
def test_training(self):
pass
@unittest.skip(reason="VitPoseBackbone does not support training yet")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="VitPoseBackbone does not support training yet")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="VitPoseBackbone does not support training yet")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@require_torch
class VitPoseBackboneTest(unittest.TestCase, BackboneTesterMixin):
all_model_classes = (VitPoseBackbone,) if is_torch_available() else ()
config_class = VitPoseBackboneConfig
has_attentions = False
def setUp(self):
self.model_tester = VitPoseBackboneModelTester(self)


@@ -1092,6 +1092,7 @@ MODELS_NOT_IN_README = [
"CLIPVisionModel",
"SiglipVisionModel",
"ChineseCLIPVisionModel",
"VitPoseBackbone",
]
# Template for new entries to add in the main README when we have missing models.


@@ -330,6 +330,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"SiglipVisionModel",
"SiglipTextModel",
"ChameleonVQVAE", # no autoclass for VQ-VAE models
"VitPoseForPoseEstimation",
"CLIPTextModel",
"MoshiForConditionalGeneration", # no auto class for speech-to-speech
]
@@ -993,6 +994,8 @@ UNDOCUMENTED_OBJECTS = [
"logging", # External module
"requires_backends", # Internal function
"AltRobertaModel", # Internal module
"VitPoseBackbone", # Internal module
"VitPoseBackboneConfig", # Internal module
]
# This list should be empty. Objects in it should get their own doc page.