mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add ViTPose (#30530)
* First draft * Make fixup * Make forward pass worké * Improve code * More improvements * More improvements * Make predictions match * More improvements * Improve image processor * Fix model tests * Add classic decoder * Convert classic decoder * Verify image processor * Fix classic decoder logits * Clean up * Add post_process_pose_estimation * Improve post_process_pose_estimation * Use AutoBackbone * Add support for MoE models * Fix tests, improve num_experts% * Improve variable names * Make fixup * More improvements * Improve post_process_pose_estimation * Compute centers and scales * Improve postprocessing * More improvements * Fix ViTPoseBackbone tests * Add docstrings, fix image processor tests * Update index * Use is_cv2_available * Add model to toctree * Add cv2 to doc tests * Remove script * Improve conversion script * Add coco_to_pascal_voc * Add box_to_center_and_scale to image_transforms * Update tests * Add integration test * Fix merge * Address comments * Replace numpy by pytorch, improve docstrings * Remove get_input_embeddings * Address comments * Move coco_to_pascal_voc * Address comment * Fix style * Address comments * Fix test * Address comment * Remove udp * Remove comment * [WIP] need to check if the numpy function is same as cv * add scipy affine_transform * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * refactor convert * add output_shape * add atol 5e-2 * Use hf_hub_download in conversion script * make box_to_center more applicable * skipt test_get_set_embedding * fix to accept array and fix CI * add co-contributor * make it to tensor type output * add torch * change to torch tensor * add more test * minor change * CI test change * import torch should be above ImageProcessor * make style * try not use torch in def * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * fix * fix * add caution * make more detail about dataset_index * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> * add docs * Update docs/source/en/model_doc/vitpose.md * Update src/transformers/models/vitpose/configuration_vitpose.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Revert "Update src/transformers/__init__.py" This reverts commit7ffa504450
. * change name * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vitpose/test_modeling_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/vitpose.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * move vitpose only function to image_processor * raise valueerror when using timm backbone * use out_indices * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove camel-case of def flip_back * rename vitposeEstimatorOutput * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix confused camelcase of MLP * remove in-place logic * clear scale description * make consistent batch format * docs update * formatting docstring * add batch tests * test docs change * Update src/transformers/models/vitpose/image_processing_vitpose.py * Update src/transformers/models/vitpose/configuration_vitpose.py * chagne ViT to Vit * change to enable MoE * make fix-copies * Update docs/source/en/model_doc/vitpose.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * extract udp * add more described docs * simple fix * change to accept target_size * make style * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vitpose/configuration_vitpose.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * change to `verify_backbone_config_arguments` * Update docs/source/en/model_doc/vitpose.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove unnecessary copy * make config immutable * enable gradient checkpointing * update inappropriate docstring * linting docs * split function for visibility * make style * check isinstances * change to acceptable use_pretrained_backbone * make style * remove copy in docs * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update docs/source/en/model_doc/vitpose.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/vitpose/modeling_vitpose.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * simple fix + make style * change input config of activation function to string * Update docs/source/en/model_doc/vitpose.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * tmp docs * delete index.md * make fix-copies * simple fix * change conversion to sam2/mllama style * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/vitpose/image_processing_vitpose.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * refactor convert * add supervision * Update src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * remove reduntant def * seperate code block for visualization * add validation for num_moe * final commit * add labels * [run-slow] vitpose, vitpose_backbone * Update src/transformers/models/vitpose/convert_vitpose_to_hf.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * enable all conversion * final commit * [run-slow] vitpose, vitpose_backbone * ruff check --fix * [run-slow] vitpose, vitpose_backbone * rename split module * [run-slow] vitpose, vitpose_backbone * fix pos_embed * Simplify init * Revert "fix pos_embed" This reverts commit2c56a4806e
. * refactor single loop * allow flag to enable custom model * efficiency of MoE to not use unused experts * make style * Fix range -> arange to avoid warning * Revert MOE router, a new one does not work * Fix postprocessing a bit (labels) * Fix type hint * Fix docs snippets * Fix links to checkpoints * Fix checkpoints in tests * Fix test * Add image to docs --------- Co-authored-by: Niels Rogge <nielsrogge@nielss-mbp.home> Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> Co-authored-by: sangbumchoi <danielsejong55@gmail.com> Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
parent
4349a0e401
commit
8490d3159c
@ -741,6 +741,8 @@
|
||||
title: ViTMatte
|
||||
- local: model_doc/vit_msn
|
||||
title: ViTMSN
|
||||
- local: model_doc/vitpose
|
||||
title: ViTPose
|
||||
- local: model_doc/yolos
|
||||
title: YOLOS
|
||||
- local: model_doc/zoedepth
|
||||
|
@ -356,6 +356,8 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [ViTMAE](model_doc/vit_mae) | ✅ | ✅ | ❌ |
|
||||
| [ViTMatte](model_doc/vitmatte) | ✅ | ❌ | ❌ |
|
||||
| [ViTMSN](model_doc/vit_msn) | ✅ | ❌ | ❌ |
|
||||
| [VitPose](model_doc/vitpose) | ✅ | ❌ | ❌ |
|
||||
| [VitPoseBackbone](model_doc/vitpose_backbone) | ✅ | ❌ | ❌ |
|
||||
| [VITS](model_doc/vits) | ✅ | ❌ | ❌ |
|
||||
| [ViViT](model_doc/vivit) | ✅ | ❌ | ❌ |
|
||||
| [Wav2Vec2](model_doc/wav2vec2) | ✅ | ✅ | ✅ |
|
||||
|
254
docs/source/en/model_doc/vitpose.md
Normal file
254
docs/source/en/model_doc/vitpose.md
Normal file
@ -0,0 +1,254 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# VitPose
|
||||
|
||||
## Overview
|
||||
|
||||
The VitPose model was proposed in [ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation](https://arxiv.org/abs/2204.12484) by Yufei Xu, Jing Zhang, Qiming Zhang, Dacheng Tao. VitPose employs a standard, non-hierarchical [Vision Transformer](https://arxiv.org/pdf/2010.11929v2) as backbone for the task of keypoint estimation. A simple decoder head is added on top to predict the heatmaps from a given image. Despite its simplicity, the model gets state-of-the-art results on the challenging MS COCO Keypoint Detection benchmark.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Although no specific domain knowledge is considered in the design, plain vision transformers have shown excellent performance in visual recognition tasks. However, little effort has been made to reveal the potential of such simple structures for pose estimation tasks. In this paper, we show the surprisingly good capabilities of plain vision transformers for pose estimation from various aspects, namely simplicity in model structure, scalability in model size, flexibility in training paradigm, and transferability of knowledge between models, through a simple baseline model called ViTPose. Specifically, ViTPose employs plain and non-hierarchical vision transformers as backbones to extract features for a given person instance and a lightweight decoder for pose estimation. It can be scaled up from 100M to 1B parameters by taking the advantages of the scalable model capacity and high parallelism of transformers, setting a new Pareto front between throughput and performance. Besides, ViTPose is very flexible regarding the attention type, input resolution, pre-training and finetuning strategy, as well as dealing with multiple pose tasks. We also empirically demonstrate that the knowledge of large ViTPose models can be easily transferred to small ones via a simple knowledge token. Experimental results show that our basic ViTPose model outperforms representative methods on the challenging MS COCO Keypoint Detection benchmark, while the largest model sets a new state-of-the-art.*
|
||||
|
||||

|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr) and [sangbumchoi](https://github.com/SangbumChoi).
|
||||
The original code can be found [here](https://github.com/ViTAE-Transformer/ViTPose).
|
||||
|
||||
## Usage Tips
|
||||
|
||||
ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt_detr.md), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
RTDetrForObjectDetection,
|
||||
VitPoseForPoseEstimation,
|
||||
)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000000139.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Stage 1. Detect humans on the image
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
# You can choose detector by your choice
|
||||
person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
|
||||
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
|
||||
|
||||
inputs = person_image_processor(images=image, return_tensors="pt").to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = person_model(**inputs)
|
||||
|
||||
results = person_image_processor.post_process_object_detection(
|
||||
outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
|
||||
)
|
||||
result = results[0] # take first image results
|
||||
|
||||
# Human label refers 0 index in COCO dataset
|
||||
person_boxes = result["boxes"][result["labels"] == 0]
|
||||
person_boxes = person_boxes.cpu().numpy()
|
||||
|
||||
# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
|
||||
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
|
||||
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Stage 2. Detect keypoints for each person found
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
|
||||
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
|
||||
|
||||
inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
|
||||
image_pose_result = pose_results[0] # results for first image
|
||||
```
|
||||
|
||||
|
||||
### Visualization for supervision user
|
||||
```py
|
||||
import supervision as sv
|
||||
|
||||
xy = torch.stack([pose_result['keypoints'] for pose_result in image_pose_result]).cpu().numpy()
|
||||
scores = torch.stack([pose_result['scores'] for pose_result in image_pose_result]).cpu().numpy()
|
||||
|
||||
key_points = sv.KeyPoints(
|
||||
xy=xy, confidence=scores
|
||||
)
|
||||
|
||||
edge_annotator = sv.EdgeAnnotator(
|
||||
color=sv.Color.GREEN,
|
||||
thickness=1
|
||||
)
|
||||
vertex_annotator = sv.VertexAnnotator(
|
||||
color=sv.Color.RED,
|
||||
radius=2
|
||||
)
|
||||
annotated_frame = edge_annotator.annotate(
|
||||
scene=image.copy(),
|
||||
key_points=key_points
|
||||
)
|
||||
annotated_frame = vertex_annotator.annotate(
|
||||
scene=annotated_frame,
|
||||
key_points=key_points
|
||||
)
|
||||
```
|
||||
|
||||
### Visualization for advanced user
|
||||
```py
|
||||
import math
|
||||
import cv2
|
||||
|
||||
def draw_points(image, keypoints, scores, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
|
||||
if pose_keypoint_color is not None:
|
||||
assert len(pose_keypoint_color) == len(keypoints)
|
||||
for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
|
||||
x_coord, y_coord = int(kpt[0]), int(kpt[1])
|
||||
if kpt_score > keypoint_score_threshold:
|
||||
color = tuple(int(c) for c in pose_keypoint_color[kid])
|
||||
if show_keypoint_weight:
|
||||
cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
|
||||
transparency = max(0, min(1, kpt_score))
|
||||
cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
|
||||
else:
|
||||
cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
|
||||
|
||||
def draw_links(image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width = 2):
|
||||
height, width, _ = image.shape
|
||||
if keypoint_edges is not None and link_colors is not None:
|
||||
assert len(link_colors) == len(keypoint_edges)
|
||||
for sk_id, sk in enumerate(keypoint_edges):
|
||||
x1, y1, score1 = (int(keypoints[sk[0], 0]), int(keypoints[sk[0], 1]), scores[sk[0]])
|
||||
x2, y2, score2 = (int(keypoints[sk[1], 0]), int(keypoints[sk[1], 1]), scores[sk[1]])
|
||||
if (
|
||||
x1 > 0
|
||||
and x1 < width
|
||||
and y1 > 0
|
||||
and y1 < height
|
||||
and x2 > 0
|
||||
and x2 < width
|
||||
and y2 > 0
|
||||
and y2 < height
|
||||
and score1 > keypoint_score_threshold
|
||||
and score2 > keypoint_score_threshold
|
||||
):
|
||||
color = tuple(int(c) for c in link_colors[sk_id])
|
||||
if show_keypoint_weight:
|
||||
X = (x1, x2)
|
||||
Y = (y1, y2)
|
||||
mean_x = np.mean(X)
|
||||
mean_y = np.mean(Y)
|
||||
length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
|
||||
polygon = cv2.ellipse2Poly(
|
||||
(int(mean_x), int(mean_y)), (int(length / 2), int(stick_width)), int(angle), 0, 360, 1
|
||||
)
|
||||
cv2.fillConvexPoly(image, polygon, color)
|
||||
transparency = max(0, min(1, 0.5 * (keypoints[sk[0], 2] + keypoints[sk[1], 2])))
|
||||
cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
|
||||
else:
|
||||
cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)
|
||||
|
||||
|
||||
# Note: keypoint_edges and color palette are dataset-specific
|
||||
keypoint_edges = model.config.edges
|
||||
|
||||
palette = np.array(
|
||||
[
|
||||
[255, 128, 0],
|
||||
[255, 153, 51],
|
||||
[255, 178, 102],
|
||||
[230, 230, 0],
|
||||
[255, 153, 255],
|
||||
[153, 204, 255],
|
||||
[255, 102, 255],
|
||||
[255, 51, 255],
|
||||
[102, 178, 255],
|
||||
[51, 153, 255],
|
||||
[255, 153, 153],
|
||||
[255, 102, 102],
|
||||
[255, 51, 51],
|
||||
[153, 255, 153],
|
||||
[102, 255, 102],
|
||||
[51, 255, 51],
|
||||
[0, 255, 0],
|
||||
[0, 0, 255],
|
||||
[255, 0, 0],
|
||||
[255, 255, 255],
|
||||
]
|
||||
)
|
||||
|
||||
link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
|
||||
keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]
|
||||
|
||||
numpy_image = np.array(image)
|
||||
|
||||
for pose_result in image_pose_result:
|
||||
scores = np.array(pose_result["scores"])
|
||||
keypoints = np.array(pose_result["keypoints"])
|
||||
|
||||
# draw each point on image
|
||||
draw_points(numpy_image, keypoints, scores, keypoint_colors, keypoint_score_threshold=0.3, radius=4, show_keypoint_weight=False)
|
||||
|
||||
# draw links
|
||||
draw_links(numpy_image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold=0.3, thickness=1, show_keypoint_weight=False)
|
||||
|
||||
pose_image = Image.fromarray(numpy_image)
|
||||
pose_image
|
||||
```
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-coco.jpg" alt="drawing" width="600"/>
|
||||
|
||||
### MoE backbone
|
||||
|
||||
To enable MoE (Mixture of Experts) function in the backbone, user has to give appropriate configuration such as `num_experts` and input value `dataset_index` to the backbone model. However, it is not used in default parameters. Below is the code snippet for usage of MoE function.
|
||||
|
||||
```py
|
||||
>>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
|
||||
>>> import torch
|
||||
|
||||
>>> config = VitPoseBackboneConfig(num_experts=3, out_indices=[-1])
|
||||
>>> model = VitPoseBackbone(config)
|
||||
|
||||
>>> pixel_values = torch.randn(3, 3, 256, 192)
|
||||
>>> dataset_index = torch.tensor([1, 2, 3])
|
||||
>>> outputs = model(pixel_values, dataset_index)
|
||||
```
|
||||
|
||||
## VitPoseImageProcessor
|
||||
|
||||
[[autodoc]] VitPoseImageProcessor
|
||||
- preprocess
|
||||
|
||||
## VitPoseConfig
|
||||
|
||||
[[autodoc]] VitPoseConfig
|
||||
|
||||
## VitPoseForPoseEstimation
|
||||
|
||||
[[autodoc]] VitPoseForPoseEstimation
|
||||
- forward
|
@ -834,6 +834,8 @@ _import_structure = {
|
||||
"models.vit_msn": ["ViTMSNConfig"],
|
||||
"models.vitdet": ["VitDetConfig"],
|
||||
"models.vitmatte": ["VitMatteConfig"],
|
||||
"models.vitpose": ["VitPoseConfig"],
|
||||
"models.vitpose_backbone": ["VitPoseBackboneConfig"],
|
||||
"models.vits": [
|
||||
"VitsConfig",
|
||||
"VitsTokenizer",
|
||||
@ -1266,6 +1268,7 @@ else:
|
||||
_import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"])
|
||||
_import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"])
|
||||
_import_structure["models.vitmatte"].append("VitMatteImageProcessor")
|
||||
_import_structure["models.vitpose"].append("VitPoseImageProcessor")
|
||||
_import_structure["models.vivit"].append("VivitImageProcessor")
|
||||
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
|
||||
_import_structure["models.zoedepth"].append("ZoeDepthImageProcessor")
|
||||
@ -3755,6 +3758,18 @@ else:
|
||||
"VitMattePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vitpose"].extend(
|
||||
[
|
||||
"VitPoseForPoseEstimation",
|
||||
"VitPosePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vitpose_backbone"].extend(
|
||||
[
|
||||
"VitPoseBackbone",
|
||||
"VitPoseBackbonePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vits"].extend(
|
||||
[
|
||||
"VitsModel",
|
||||
@ -5877,6 +5892,8 @@ if TYPE_CHECKING:
|
||||
from .models.vit_msn import ViTMSNConfig
|
||||
from .models.vitdet import VitDetConfig
|
||||
from .models.vitmatte import VitMatteConfig
|
||||
from .models.vitpose import VitPoseConfig
|
||||
from .models.vitpose_backbone import VitPoseBackboneConfig
|
||||
from .models.vits import (
|
||||
VitsConfig,
|
||||
VitsTokenizer,
|
||||
@ -6311,6 +6328,7 @@ if TYPE_CHECKING:
|
||||
from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor
|
||||
from .models.vit import ViTFeatureExtractor, ViTImageProcessor
|
||||
from .models.vitmatte import VitMatteImageProcessor
|
||||
from .models.vitpose import VitPoseImageProcessor
|
||||
from .models.vivit import VivitImageProcessor
|
||||
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
|
||||
from .models.zoedepth import ZoeDepthImageProcessor
|
||||
@ -8294,6 +8312,11 @@ if TYPE_CHECKING:
|
||||
VitMatteForImageMatting,
|
||||
VitMattePreTrainedModel,
|
||||
)
|
||||
from .models.vitpose import (
|
||||
VitPoseForPoseEstimation,
|
||||
VitPosePreTrainedModel,
|
||||
)
|
||||
from .models.vitpose_backbone import VitPoseBackbone, VitPoseBackbonePreTrainedModel
|
||||
from .models.vits import (
|
||||
VitsModel,
|
||||
VitsPreTrainedModel,
|
||||
|
@ -277,6 +277,8 @@ from . import (
|
||||
vit_msn,
|
||||
vitdet,
|
||||
vitmatte,
|
||||
vitpose,
|
||||
vitpose_backbone,
|
||||
vits,
|
||||
vivit,
|
||||
wav2vec2,
|
||||
|
@ -309,6 +309,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("vit_msn", "ViTMSNConfig"),
|
||||
("vitdet", "VitDetConfig"),
|
||||
("vitmatte", "VitMatteConfig"),
|
||||
("vitpose", "VitPoseConfig"),
|
||||
("vitpose_backbone", "VitPoseBackboneConfig"),
|
||||
("vits", "VitsConfig"),
|
||||
("vivit", "VivitConfig"),
|
||||
("wav2vec2", "Wav2Vec2Config"),
|
||||
@ -642,6 +644,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("vit_msn", "ViTMSN"),
|
||||
("vitdet", "VitDet"),
|
||||
("vitmatte", "ViTMatte"),
|
||||
("vitpose", "VitPose"),
|
||||
("vitpose_backbone", "VitPoseBackbone"),
|
||||
("vits", "VITS"),
|
||||
("vivit", "ViViT"),
|
||||
("wav2vec2", "Wav2Vec2"),
|
||||
|
@ -1396,6 +1396,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
|
||||
("textnet", "TextNetBackbone"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
("vitdet", "VitDetBackbone"),
|
||||
("vitpose_backbone", "VitPoseBackbone"),
|
||||
]
|
||||
)
|
||||
|
||||
|
28
src/transformers/models/vitpose/__init__.py
Normal file
28
src/transformers/models/vitpose/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vitpose import *
|
||||
from .image_processing_vitpose import *
|
||||
from .modeling_vitpose import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
124
src/transformers/models/vitpose/configuration_vitpose.py
Normal file
124
src/transformers/models/vitpose/configuration_vitpose.py
Normal file
@ -0,0 +1,124 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""VitPose model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VitPoseConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`VitPoseForPoseEstimation`]. It is used to instantiate a
|
||||
VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the VitPose
|
||||
[usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitPoseBackboneConfig()`):
|
||||
The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported.
|
||||
backbone (`str`, *optional*):
|
||||
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
||||
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
||||
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use pretrained weights for the backbone.
|
||||
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||
library.
|
||||
backbone_kwargs (`dict`, *optional*):
|
||||
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
scale_factor (`int`, *optional*, defaults to 4):
|
||||
Factor to upscale the feature maps coming from the ViT backbone.
|
||||
use_simple_decoder (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use a `VitPoseSimpleDecoder` to decode the feature maps from the backbone into heatmaps. Otherwise it uses `VitPoseClassicDecoder`.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VitPoseConfig, VitPoseForPoseEstimation
|
||||
|
||||
>>> # Initializing a VitPose configuration
|
||||
>>> configuration = VitPoseConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the configuration
|
||||
>>> model = VitPoseForPoseEstimation(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "vitpose"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone_config: PretrainedConfig = None,
|
||||
backbone: str = None,
|
||||
use_pretrained_backbone: bool = False,
|
||||
use_timm_backbone: bool = False,
|
||||
backbone_kwargs: dict = None,
|
||||
initializer_range: float = 0.02,
|
||||
scale_factor: int = 4,
|
||||
use_simple_decoder: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if use_pretrained_backbone:
|
||||
logger.info(
|
||||
"`use_pretrained_backbone` is `True`. For the pure inference purpose of VitPose weight do not set this value."
|
||||
)
|
||||
if use_timm_backbone:
|
||||
raise ValueError("use_timm_backbone set `True` is not supported at the moment.")
|
||||
|
||||
if backbone_config is None and backbone is None:
|
||||
logger.info("`backbone_config` is `None`. Initializing the config with the default `VitPose` backbone.")
|
||||
backbone_config = CONFIG_MAPPING["vitpose_backbone"](out_indices=[4])
|
||||
elif isinstance(backbone_config, dict):
|
||||
backbone_model_type = backbone_config.get("model_type")
|
||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||
backbone_config = config_class.from_dict(backbone_config)
|
||||
|
||||
verify_backbone_config_arguments(
|
||||
use_timm_backbone=use_timm_backbone,
|
||||
use_pretrained_backbone=use_pretrained_backbone,
|
||||
backbone=backbone,
|
||||
backbone_config=backbone_config,
|
||||
backbone_kwargs=backbone_kwargs,
|
||||
)
|
||||
|
||||
self.backbone_config = backbone_config
|
||||
self.backbone = backbone
|
||||
self.use_pretrained_backbone = use_pretrained_backbone
|
||||
self.use_timm_backbone = use_timm_backbone
|
||||
self.backbone_kwargs = backbone_kwargs
|
||||
|
||||
self.initializer_range = initializer_range
|
||||
self.scale_factor = scale_factor
|
||||
self.use_simple_decoder = use_simple_decoder
|
||||
|
||||
|
||||
__all__ = ["VitPoseConfig"]
|
355
src/transformers/models/vitpose/convert_vitpose_to_hf.py
Normal file
355
src/transformers/models/vitpose/convert_vitpose_to_hf.py
Normal file
@ -0,0 +1,355 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert VitPose checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/vitae-transformer/vitpose
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import VitPoseBackboneConfig, VitPoseConfig, VitPoseForPoseEstimation, VitPoseImageProcessor
|
||||
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
r"patch_embed.proj": "embeddings.patch_embeddings.projection",
|
||||
r"pos_embed": "embeddings.position_embeddings",
|
||||
r"blocks": "encoder.layer",
|
||||
r"attn.proj": "attention.output.dense",
|
||||
r"attn": "attention.self",
|
||||
r"norm1": "layernorm_before",
|
||||
r"norm2": "layernorm_after",
|
||||
r"last_norm": "layernorm",
|
||||
r"keypoint_head": "head",
|
||||
r"final_layer": "conv",
|
||||
}
|
||||
|
||||
MODEL_TO_FILE_NAME_MAPPING = {
|
||||
"vitpose-base-simple": "vitpose-b-simple.pth",
|
||||
"vitpose-base": "vitpose-b.pth",
|
||||
"vitpose-base-coco-aic-mpii": "vitpose_base_coco_aic_mpii.pth",
|
||||
"vitpose-plus-base": "vitpose+_base.pth",
|
||||
}
|
||||
|
||||
|
||||
def get_config(model_name):
|
||||
num_experts = 6 if "plus" in model_name else 1
|
||||
part_features = 192 if "plus" in model_name else 0
|
||||
|
||||
backbone_config = VitPoseBackboneConfig(out_indices=[12], num_experts=num_experts, part_features=part_features)
|
||||
# size of the architecture
|
||||
if "small" in model_name:
|
||||
backbone_config.hidden_size = 768
|
||||
backbone_config.intermediate_size = 2304
|
||||
backbone_config.num_hidden_layers = 8
|
||||
backbone_config.num_attention_heads = 8
|
||||
elif "large" in model_name:
|
||||
backbone_config.hidden_size = 1024
|
||||
backbone_config.intermediate_size = 4096
|
||||
backbone_config.num_hidden_layers = 24
|
||||
backbone_config.num_attention_heads = 16
|
||||
elif "huge" in model_name:
|
||||
backbone_config.hidden_size = 1280
|
||||
backbone_config.intermediate_size = 5120
|
||||
backbone_config.num_hidden_layers = 32
|
||||
backbone_config.num_attention_heads = 16
|
||||
|
||||
use_simple_decoder = "simple" in model_name
|
||||
|
||||
edges = [
|
||||
[15, 13],
|
||||
[13, 11],
|
||||
[16, 14],
|
||||
[14, 12],
|
||||
[11, 12],
|
||||
[5, 11],
|
||||
[6, 12],
|
||||
[5, 6],
|
||||
[5, 7],
|
||||
[6, 8],
|
||||
[7, 9],
|
||||
[8, 10],
|
||||
[1, 2],
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[1, 3],
|
||||
[2, 4],
|
||||
[3, 5],
|
||||
[4, 6],
|
||||
]
|
||||
id2label = {
|
||||
0: "Nose",
|
||||
1: "L_Eye",
|
||||
2: "R_Eye",
|
||||
3: "L_Ear",
|
||||
4: "R_Ear",
|
||||
5: "L_Shoulder",
|
||||
6: "R_Shoulder",
|
||||
7: "L_Elbow",
|
||||
8: "R_Elbow",
|
||||
9: "L_Wrist",
|
||||
10: "R_Wrist",
|
||||
11: "L_Hip",
|
||||
12: "R_Hip",
|
||||
13: "L_Knee",
|
||||
14: "R_Knee",
|
||||
15: "L_Ankle",
|
||||
16: "R_Ankle",
|
||||
}
|
||||
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
config = VitPoseConfig(
|
||||
backbone_config=backbone_config,
|
||||
num_labels=17,
|
||||
use_simple_decoder=use_simple_decoder,
|
||||
edges=edges,
|
||||
id2label=id2label,
|
||||
label2id=label2id,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
|
||||
"""
|
||||
This function should be applied only once, on the concatenated keys to efficiently rename using
|
||||
the key mappings.
|
||||
"""
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text = "\n".join(state_dict_keys)
|
||||
new_text = old_text
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||
if replacement is None:
|
||||
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||
continue
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||
return output_dict
|
||||
|
||||
|
||||
# We will verify our results on a COCO image
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000000139.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def write_model(model_path, model_name, push_to_hub, check_logits=True):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Vision model params and config
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# params from config
|
||||
config = get_config(model_name)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Convert weights
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# load original state_dict
|
||||
filename = MODEL_TO_FILE_NAME_MAPPING[model_name]
|
||||
print(f"Fetching all parameters from the checkpoint at {filename}...")
|
||||
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id="nielsr/vitpose-original-checkpoints", filename=filename, repo_type="model"
|
||||
)
|
||||
|
||||
print("Converting model...")
|
||||
original_state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
|
||||
all_keys = list(original_state_dict.keys())
|
||||
new_keys = convert_old_keys_to_new_keys(all_keys)
|
||||
|
||||
dim = config.backbone_config.hidden_size
|
||||
|
||||
state_dict = {}
|
||||
for key in all_keys:
|
||||
new_key = new_keys[key]
|
||||
value = original_state_dict[key]
|
||||
|
||||
if re.search("associate_heads", new_key) or re.search("backbone.cls_token", new_key):
|
||||
# This associated_heads is concept of auxiliary head so does not require in inference stage.
|
||||
# backbone.cls_token is optional forward function for dynamically change of size, see detail in https://github.com/ViTAE-Transformer/ViTPose/issues/34
|
||||
pass
|
||||
elif re.search("qkv", new_key):
|
||||
state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim]
|
||||
state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2]
|
||||
state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:]
|
||||
elif re.search("head", new_key) and not config.use_simple_decoder:
|
||||
# Pattern for deconvolution layers
|
||||
deconv_pattern = r"deconv_layers\.(0|3)\.weight"
|
||||
new_key = re.sub(deconv_pattern, lambda m: f"deconv{int(m.group(1))//3 + 1}.weight", new_key)
|
||||
# Pattern for batch normalization layers
|
||||
bn_patterns = [
|
||||
(r"deconv_layers\.(\d+)\.weight", r"batchnorm\1.weight"),
|
||||
(r"deconv_layers\.(\d+)\.bias", r"batchnorm\1.bias"),
|
||||
(r"deconv_layers\.(\d+)\.running_mean", r"batchnorm\1.running_mean"),
|
||||
(r"deconv_layers\.(\d+)\.running_var", r"batchnorm\1.running_var"),
|
||||
(r"deconv_layers\.(\d+)\.num_batches_tracked", r"batchnorm\1.num_batches_tracked"),
|
||||
]
|
||||
|
||||
for pattern, replacement in bn_patterns:
|
||||
if re.search(pattern, new_key):
|
||||
# Convert the layer number to the correct batch norm index
|
||||
layer_num = int(re.search(pattern, key).group(1))
|
||||
bn_num = layer_num // 3 + 1
|
||||
new_key = re.sub(pattern, replacement.replace(r"\1", str(bn_num)), new_key)
|
||||
state_dict[new_key] = value
|
||||
else:
|
||||
state_dict[new_key] = value
|
||||
|
||||
print("Loading the checkpoint in a Vitpose model.")
|
||||
model = VitPoseForPoseEstimation(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
print("Checkpoint loaded successfully.")
|
||||
|
||||
# create image processor
|
||||
image_processor = VitPoseImageProcessor()
|
||||
|
||||
# verify image processor
|
||||
image = prepare_img()
|
||||
boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
|
||||
pixel_values = image_processor(images=image, boxes=boxes, return_tensors="pt").pixel_values
|
||||
|
||||
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="vitpose_batch_data.pt", repo_type="dataset")
|
||||
original_pixel_values = torch.load(filepath, map_location="cpu")["img"]
|
||||
assert torch.allclose(pixel_values, original_pixel_values, atol=1e-1)
|
||||
|
||||
dataset_index = torch.tensor([0])
|
||||
|
||||
with torch.no_grad():
|
||||
# first forward pass
|
||||
outputs = model(pixel_values, dataset_index=dataset_index)
|
||||
output_heatmap = outputs.heatmaps
|
||||
|
||||
# second forward pass (flipped)
|
||||
# this is done since the model uses `flip_test=True` in its test config
|
||||
pixel_values_flipped = torch.flip(pixel_values, [3])
|
||||
outputs_flipped = model(
|
||||
pixel_values_flipped,
|
||||
dataset_index=dataset_index,
|
||||
flip_pairs=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]),
|
||||
)
|
||||
output_flipped_heatmap = outputs_flipped.heatmaps
|
||||
|
||||
outputs.heatmaps = (output_heatmap + output_flipped_heatmap) * 0.5
|
||||
|
||||
# Verify pose_results
|
||||
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
|
||||
|
||||
if check_logits:
|
||||
if model_name == "vitpose-base-simple":
|
||||
assert torch.allclose(
|
||||
pose_results[1]["keypoints"][0],
|
||||
torch.tensor([3.98180511e02, 1.81808380e02]),
|
||||
atol=5e-2,
|
||||
)
|
||||
assert torch.allclose(
|
||||
pose_results[1]["scores"][0],
|
||||
torch.tensor([8.66642594e-01]),
|
||||
atol=5e-2,
|
||||
)
|
||||
elif model_name == "vitpose-base":
|
||||
assert torch.allclose(
|
||||
pose_results[1]["keypoints"][0],
|
||||
torch.tensor([3.9807913e02, 1.8182812e02]),
|
||||
atol=5e-2,
|
||||
)
|
||||
assert torch.allclose(
|
||||
pose_results[1]["scores"][0],
|
||||
torch.tensor([8.8235235e-01]),
|
||||
atol=5e-2,
|
||||
)
|
||||
elif model_name == "vitpose-base-coco-aic-mpii":
|
||||
assert torch.allclose(
|
||||
pose_results[1]["keypoints"][0],
|
||||
torch.tensor([3.98305542e02, 1.81741592e02]),
|
||||
atol=5e-2,
|
||||
)
|
||||
assert torch.allclose(
|
||||
pose_results[1]["scores"][0],
|
||||
torch.tensor([8.69966745e-01]),
|
||||
atol=5e-2,
|
||||
)
|
||||
elif model_name == "vitpose-plus-base":
|
||||
assert torch.allclose(
|
||||
pose_results[1]["keypoints"][0],
|
||||
torch.tensor([3.98201294e02, 1.81728302e02]),
|
||||
atol=5e-2,
|
||||
)
|
||||
assert torch.allclose(
|
||||
pose_results[1]["scores"][0],
|
||||
torch.tensor([8.75046968e-01]),
|
||||
atol=5e-2,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
print("Conversion successfully done.")
|
||||
|
||||
# save the model to a local directory
|
||||
model.save_pretrained(model_path)
|
||||
image_processor.save_pretrained(model_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model and image processor for {model_name} to hub")
|
||||
model.push_to_hub(f"danelcsb/{model_name}")
|
||||
image_processor.push_to_hub(f"danelcsb/{model_name}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="vitpose-base-simple",
|
||||
choices=MODEL_TO_FILE_NAME_MAPPING.keys(),
|
||||
type=str,
|
||||
help="Name of the VitPose model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Whether to check the logits of public converted model to the 🤗 hub. You can disable when using custom model.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
write_model(
|
||||
model_path=args.pytorch_dump_folder_path,
|
||||
model_name=args.model_name,
|
||||
push_to_hub=args.push_to_hub,
|
||||
check_logits=args.check_logits,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
684
src/transformers/models/vitpose/image_processing_vitpose.py
Normal file
684
src/transformers/models/vitpose/image_processing_vitpose.py
Normal file
@ -0,0 +1,684 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for VitPose."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from ...image_transforms import to_channel_dimension_format
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from ...utils import TensorType, is_scipy_available, is_torch_available, is_vision_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.linalg import inv
|
||||
from scipy.ndimage import affine_transform, gaussian_filter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_vitpose import VitPoseEstimatorOutput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# inspired by https://github.com/ViTAE-Transformer/ViTPose/blob/d5216452796c90c6bc29f5c5ec0bdba94366768a/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py#L132
|
||||
def box_to_center_and_scale(
|
||||
box: Union[Tuple, List, np.ndarray],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
normalize_factor: float = 200.0,
|
||||
padding_factor: float = 1.25,
|
||||
):
|
||||
"""
|
||||
Encodes a bounding box in COCO format into (center, scale).
|
||||
|
||||
Args:
|
||||
box (`Tuple`, `List`, or `np.ndarray`):
|
||||
Bounding box in COCO format (top_left_x, top_left_y, width, height).
|
||||
image_width (`int`):
|
||||
Image width.
|
||||
image_height (`int`):
|
||||
Image height.
|
||||
normalize_factor (`float`):
|
||||
Width and height scale factor.
|
||||
padding_factor (`float`):
|
||||
Bounding box padding factor.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing center and scale.
|
||||
|
||||
- `np.ndarray` [float32](2,): Center of the bbox (x, y).
|
||||
- `np.ndarray` [float32](2,): Scale of the bbox width & height.
|
||||
"""
|
||||
|
||||
top_left_x, top_left_y, width, height = box[:4]
|
||||
aspect_ratio = image_width / image_height
|
||||
center = np.array([top_left_x + width * 0.5, top_left_y + height * 0.5], dtype=np.float32)
|
||||
|
||||
if width > aspect_ratio * height:
|
||||
height = width * 1.0 / aspect_ratio
|
||||
elif width < aspect_ratio * height:
|
||||
width = height * aspect_ratio
|
||||
|
||||
scale = np.array([width / normalize_factor, height / normalize_factor], dtype=np.float32)
|
||||
scale = scale * padding_factor
|
||||
|
||||
return center, scale
|
||||
|
||||
|
||||
def coco_to_pascal_voc(bboxes: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Converts bounding boxes from the COCO format to the Pascal VOC format.
|
||||
|
||||
In other words, converts from (top_left_x, top_left_y, width, height) format
|
||||
to (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
|
||||
|
||||
Args:
|
||||
bboxes (`np.ndarray` of shape `(batch_size, 4)):
|
||||
Bounding boxes in COCO format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray` of shape `(batch_size, 4) in Pascal VOC format.
|
||||
"""
|
||||
bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0] - 1
|
||||
bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1] - 1
|
||||
|
||||
return bboxes
|
||||
|
||||
|
||||
def get_keypoint_predictions(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get keypoint predictions from score maps.
|
||||
|
||||
Args:
|
||||
heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
|
||||
Model predicted heatmaps.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing aggregated results.
|
||||
|
||||
- coords (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
|
||||
Predicted keypoint location.
|
||||
- scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
|
||||
Scores (confidence) of the keypoints.
|
||||
"""
|
||||
if not isinstance(heatmaps, np.ndarray):
|
||||
raise ValueError("Heatmaps should be np.ndarray")
|
||||
if heatmaps.ndim != 4:
|
||||
raise ValueError("Heatmaps should be 4-dimensional")
|
||||
|
||||
batch_size, num_keypoints, _, width = heatmaps.shape
|
||||
heatmaps_reshaped = heatmaps.reshape((batch_size, num_keypoints, -1))
|
||||
idx = np.argmax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))
|
||||
scores = np.amax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))
|
||||
|
||||
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
|
||||
preds[:, :, 0] = preds[:, :, 0] % width
|
||||
preds[:, :, 1] = preds[:, :, 1] // width
|
||||
|
||||
preds = np.where(np.tile(scores, (1, 1, 2)) > 0.0, preds, -1)
|
||||
return preds, scores
|
||||
|
||||
|
||||
def post_dark_unbiased_data_processing(coords: np.ndarray, batch_heatmaps: np.ndarray, kernel: int = 3) -> np.ndarray:
|
||||
"""DARK post-pocessing. Implemented by unbiased_data_processing.
|
||||
|
||||
Paper references:
|
||||
- Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
||||
- Zhang et al. Distribution-Aware Coordinate Representation for Human Pose Estimation (CVPR 2020).
|
||||
|
||||
Args:
|
||||
coords (`np.ndarray` of shape `(num_persons, num_keypoints, 2)`):
|
||||
Initial coordinates of human pose.
|
||||
batch_heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
|
||||
Batched heatmaps as predicted by the model.
|
||||
A batch_size of 1 is used for the bottom up paradigm where all persons share the same heatmap.
|
||||
A batch_size of `num_persons` is used for the top down paradigm where each person has its own heatmaps.
|
||||
kernel (`int`, *optional*, defaults to 3):
|
||||
Gaussian kernel size (K) for modulation.
|
||||
|
||||
Returns:
|
||||
`np.ndarray` of shape `(num_persons, num_keypoints, 2)` ):
|
||||
Refined coordinates.
|
||||
"""
|
||||
batch_size, num_keypoints, height, width = batch_heatmaps.shape
|
||||
num_coords = coords.shape[0]
|
||||
if not (batch_size == 1 or batch_size == num_coords):
|
||||
raise ValueError("The batch size of heatmaps should be 1 or equal to the batch size of coordinates.")
|
||||
radius = int((kernel - 1) // 2)
|
||||
batch_heatmaps = np.array(
|
||||
[
|
||||
[gaussian_filter(heatmap, sigma=0.8, radius=(radius, radius), axes=(0, 1)) for heatmap in heatmaps]
|
||||
for heatmaps in batch_heatmaps
|
||||
]
|
||||
)
|
||||
batch_heatmaps = np.clip(batch_heatmaps, 0.001, 50)
|
||||
batch_heatmaps = np.log(batch_heatmaps)
|
||||
|
||||
batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten()
|
||||
|
||||
# calculate indices for coordinates
|
||||
index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (width + 2)
|
||||
index += (width + 2) * (height + 2) * np.arange(0, batch_size * num_keypoints).reshape(-1, num_keypoints)
|
||||
index = index.astype(int).reshape(-1, 1)
|
||||
i_ = batch_heatmaps_pad[index]
|
||||
ix1 = batch_heatmaps_pad[index + 1]
|
||||
iy1 = batch_heatmaps_pad[index + width + 2]
|
||||
ix1y1 = batch_heatmaps_pad[index + width + 3]
|
||||
ix1_y1_ = batch_heatmaps_pad[index - width - 3]
|
||||
ix1_ = batch_heatmaps_pad[index - 1]
|
||||
iy1_ = batch_heatmaps_pad[index - 2 - width]
|
||||
|
||||
# calculate refined coordinates using Newton's method
|
||||
dx = 0.5 * (ix1 - ix1_)
|
||||
dy = 0.5 * (iy1 - iy1_)
|
||||
derivative = np.concatenate([dx, dy], axis=1)
|
||||
derivative = derivative.reshape(num_coords, num_keypoints, 2, 1)
|
||||
dxx = ix1 - 2 * i_ + ix1_
|
||||
dyy = iy1 - 2 * i_ + iy1_
|
||||
dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
|
||||
hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
|
||||
hessian = hessian.reshape(num_coords, num_keypoints, 2, 2)
|
||||
hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
|
||||
coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze()
|
||||
return coords
|
||||
|
||||
|
||||
def transform_preds(coords: np.ndarray, center: np.ndarray, scale: np.ndarray, output_size: np.ndarray) -> np.ndarray:
|
||||
"""Get final keypoint predictions from heatmaps and apply scaling and
|
||||
translation to map them back to the image.
|
||||
|
||||
Note:
|
||||
num_keypoints: K
|
||||
|
||||
Args:
|
||||
coords (`np.ndarray` of shape `(num_keypoints, ndims)`):
|
||||
|
||||
* If ndims=2, corrds are predicted keypoint location.
|
||||
* If ndims=4, corrds are composed of (x, y, scores, tags)
|
||||
* If ndims=5, corrds are composed of (x, y, scores, tags,
|
||||
flipped_tags)
|
||||
|
||||
center (`np.ndarray` of shape `(2,)`):
|
||||
Center of the bounding box (x, y).
|
||||
scale (`np.ndarray` of shape `(2,)`):
|
||||
Scale of the bounding box wrt original image of width and height.
|
||||
output_size (`np.ndarray` of shape `(2,)`):
|
||||
Size of the destination heatmaps in (height, width) format.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Predicted coordinates in the images.
|
||||
"""
|
||||
if coords.shape[1] not in (2, 4, 5):
|
||||
raise ValueError("Coordinates need to have either 2, 4 or 5 dimensions.")
|
||||
if len(center) != 2:
|
||||
raise ValueError("Center needs to have 2 elements, one for x and one for y.")
|
||||
if len(scale) != 2:
|
||||
raise ValueError("Scale needs to consist of a width and height")
|
||||
if len(output_size) != 2:
|
||||
raise ValueError("Output size needs to consist of a height and width")
|
||||
|
||||
# Recover the scale which is normalized by a factor of 200.
|
||||
scale = scale * 200.0
|
||||
|
||||
# We use unbiased data processing
|
||||
scale_y = scale[1] / (output_size[0] - 1.0)
|
||||
scale_x = scale[0] / (output_size[1] - 1.0)
|
||||
|
||||
target_coords = np.ones_like(coords)
|
||||
target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
|
||||
target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
|
||||
|
||||
return target_coords
|
||||
|
||||
|
||||
def get_warp_matrix(theta: float, size_input: np.ndarray, size_dst: np.ndarray, size_target: np.ndarray):
|
||||
"""
|
||||
Calculate the transformation matrix under the constraint of unbiased. Paper ref: Huang et al. The Devil is in the
|
||||
Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
||||
|
||||
Source: https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
|
||||
|
||||
Args:
|
||||
theta (`float`):
|
||||
Rotation angle in degrees.
|
||||
size_input (`np.ndarray`):
|
||||
Size of input image [width, height].
|
||||
size_dst (`np.ndarray`):
|
||||
Size of output image [width, height].
|
||||
size_target (`np.ndarray`):
|
||||
Size of ROI in input plane [w, h].
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: A matrix for transformation.
|
||||
"""
|
||||
theta = np.deg2rad(theta)
|
||||
matrix = np.zeros((2, 3), dtype=np.float32)
|
||||
scale_x = size_dst[0] / size_target[0]
|
||||
scale_y = size_dst[1] / size_target[1]
|
||||
matrix[0, 0] = math.cos(theta) * scale_x
|
||||
matrix[0, 1] = -math.sin(theta) * scale_x
|
||||
matrix[0, 2] = scale_x * (
|
||||
-0.5 * size_input[0] * math.cos(theta) + 0.5 * size_input[1] * math.sin(theta) + 0.5 * size_target[0]
|
||||
)
|
||||
matrix[1, 0] = math.sin(theta) * scale_y
|
||||
matrix[1, 1] = math.cos(theta) * scale_y
|
||||
matrix[1, 2] = scale_y * (
|
||||
-0.5 * size_input[0] * math.sin(theta) - 0.5 * size_input[1] * math.cos(theta) + 0.5 * size_target[1]
|
||||
)
|
||||
return matrix
|
||||
|
||||
|
||||
def scipy_warp_affine(src, M, size):
|
||||
"""
|
||||
This function implements cv2.warpAffine function using affine_transform in scipy. See https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html and https://docs.opencv.org/4.x/d4/d61/tutorial_warp_affine.html for more details.
|
||||
|
||||
Note: the original implementation of cv2.warpAffine uses cv2.INTER_LINEAR.
|
||||
"""
|
||||
channels = [src[..., i] for i in range(src.shape[-1])]
|
||||
|
||||
# Convert to a 3x3 matrix used by SciPy
|
||||
M_scipy = np.vstack([M, [0, 0, 1]])
|
||||
# If you have a matrix for the ‘push’ transformation, use its inverse (numpy.linalg.inv) in this function.
|
||||
M_inv = inv(M_scipy)
|
||||
M_inv[0, 0], M_inv[0, 1], M_inv[1, 0], M_inv[1, 1], M_inv[0, 2], M_inv[1, 2] = (
|
||||
M_inv[1, 1],
|
||||
M_inv[1, 0],
|
||||
M_inv[0, 1],
|
||||
M_inv[0, 0],
|
||||
M_inv[1, 2],
|
||||
M_inv[0, 2],
|
||||
)
|
||||
|
||||
new_src = [affine_transform(channel, M_inv, output_shape=size, order=1) for channel in channels]
|
||||
new_src = np.stack(new_src, axis=-1)
|
||||
return new_src
|
||||
|
||||
|
||||
class VitPoseImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a VitPose image processor.
|
||||
|
||||
Args:
|
||||
do_affine_transform (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply an affine transformation to the input images.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 192}`):
|
||||
Resolution of the image after `affine_transform` is applied. Only has an effect if `do_affine_transform` is set to `True`. Can
|
||||
be overriden by `size` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to normalize the input with mean and standard deviation.
|
||||
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`, *optional*):
|
||||
The sequence of means for each channel, to be used when normalizing images.
|
||||
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`, *optional*):
|
||||
The sequence of standard deviations for each channel, to be used when normalizing images.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_affine_transform: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.do_affine_transform = do_affine_transform
|
||||
self.size = size if size is not None else {"height": 256, "width": 192}
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
||||
self.normalize_factor = 200.0
|
||||
|
||||
def affine_transform(
|
||||
self,
|
||||
image: np.array,
|
||||
center: Tuple[float],
|
||||
scale: Tuple[float],
|
||||
rotation: float,
|
||||
size: Dict[str, int],
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.array:
|
||||
"""
|
||||
Apply an affine transformation to an image.
|
||||
|
||||
Args:
|
||||
image (`np.array`):
|
||||
Image to transform.
|
||||
center (`Tuple[float]`):
|
||||
Center of the bounding box (x, y).
|
||||
scale (`Tuple[float]`):
|
||||
Scale of the bounding box with respect to height/width.
|
||||
rotation (`float`):
|
||||
Rotation angle in degrees.
|
||||
size (`Dict[str, int]`):
|
||||
Size of the destination image.
|
||||
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format of the output image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image.
|
||||
"""
|
||||
|
||||
data_format = input_data_format if data_format is None else data_format
|
||||
|
||||
size = (size["width"], size["height"])
|
||||
|
||||
# one uses a pixel standard deviation of 200 pixels
|
||||
transformation = get_warp_matrix(rotation, center * 2.0, np.array(size) - 1.0, scale * 200.0)
|
||||
|
||||
# input image requires channels last format
|
||||
image = (
|
||||
image
|
||||
if input_data_format == ChannelDimension.LAST
|
||||
else to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
|
||||
)
|
||||
image = scipy_warp_affine(src=image, M=transformation, size=(size[1], size[0]))
|
||||
|
||||
image = to_channel_dimension_format(image, data_format, ChannelDimension.LAST)
|
||||
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
boxes: Union[List[List[float]], np.ndarray],
|
||||
do_affine_transform: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
|
||||
boxes (`List[List[List[float]]]` or `np.ndarray`):
|
||||
List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
|
||||
box coordinates in COCO format (top_left_x, top_left_y, width, height).
|
||||
|
||||
do_affine_transform (`bool`, *optional*, defaults to `self.do_affine_transform`):
|
||||
Whether to apply an affine transformation to the input images.
|
||||
size (`Dict[str, int]` *optional*, defaults to `self.size`):
|
||||
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
|
||||
resizing.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use if `do_normalize` is set to `True`.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
|
||||
width).
|
||||
"""
|
||||
do_affine_transform = do_affine_transform if do_affine_transform is not None else self.do_affine_transform
|
||||
size = size if size is not None else self.size
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if isinstance(boxes, list) and len(images) != len(boxes):
|
||||
raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {len(boxes)}")
|
||||
elif isinstance(boxes, np.ndarray) and len(images) != boxes.shape[0]:
|
||||
raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {boxes.shape[0]}")
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
# transformations (affine transformation + rescaling + normalization)
|
||||
if self.do_affine_transform:
|
||||
new_images = []
|
||||
for image, image_boxes in zip(images, boxes):
|
||||
for box in image_boxes:
|
||||
center, scale = box_to_center_and_scale(
|
||||
box,
|
||||
image_width=size["width"],
|
||||
image_height=size["height"],
|
||||
normalize_factor=self.normalize_factor,
|
||||
)
|
||||
transformed_image = self.affine_transform(
|
||||
image, center, scale, rotation=0, size=size, input_data_format=input_data_format
|
||||
)
|
||||
new_images.append(transformed_image)
|
||||
images = new_images
|
||||
|
||||
# For batch processing, the number of boxes must be consistent across all images in the batch.
|
||||
# When using a list input, the number of boxes can vary dynamically per image.
|
||||
# The image processor creates pixel_values of shape (batch_size*num_persons, num_channels, height, width)
|
||||
|
||||
all_images = []
|
||||
for image in images:
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(
|
||||
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
all_images.append(image)
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
for image in all_images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def keypoints_from_heatmaps(
|
||||
self,
|
||||
heatmaps: np.ndarray,
|
||||
center: np.ndarray,
|
||||
scale: np.ndarray,
|
||||
kernel: int = 11,
|
||||
):
|
||||
"""
|
||||
Get final keypoint predictions from heatmaps and transform them back to
|
||||
the image.
|
||||
|
||||
Args:
|
||||
heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width])`):
|
||||
Model predicted heatmaps.
|
||||
center (`np.ndarray` of shape `(batch_size, 2)`):
|
||||
Center of the bounding box (x, y).
|
||||
scale (`np.ndarray` of shape `(batch_size, 2)`):
|
||||
Scale of the bounding box wrt original images of width and height.
|
||||
kernel (int, *optional*, defaults to 11):
|
||||
Gaussian kernel size (K) for modulation, which should match the heatmap gaussian sigma when training.
|
||||
K=17 for sigma=3 and k=11 for sigma=2.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing keypoint predictions and scores.
|
||||
|
||||
- preds (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
|
||||
Predicted keypoint location in images.
|
||||
- scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
|
||||
Scores (confidence) of the keypoints.
|
||||
"""
|
||||
batch_size, _, height, width = heatmaps.shape
|
||||
|
||||
coords, scores = get_keypoint_predictions(heatmaps)
|
||||
|
||||
preds = post_dark_unbiased_data_processing(coords, heatmaps, kernel=kernel)
|
||||
|
||||
# Transform back to the image
|
||||
for i in range(batch_size):
|
||||
preds[i] = transform_preds(preds[i], center=center[i], scale=scale[i], output_size=[height, width])
|
||||
|
||||
return preds, scores
|
||||
|
||||
def post_process_pose_estimation(
|
||||
self,
|
||||
outputs: "VitPoseEstimatorOutput",
|
||||
boxes: Union[List[List[List[float]]], np.ndarray],
|
||||
kernel_size: int = 11,
|
||||
threshold: float = None,
|
||||
target_sizes: Union[TensorType, List[Tuple]] = None,
|
||||
):
|
||||
"""
|
||||
Transform the heatmaps into keypoint predictions and transform them back to the image.
|
||||
|
||||
Args:
|
||||
outputs (`VitPoseEstimatorOutput`):
|
||||
VitPoseForPoseEstimation model outputs.
|
||||
boxes (`List[List[List[float]]]` or `np.ndarray`):
|
||||
List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
|
||||
box coordinates in COCO format (top_left_x, top_left_y, width, height).
|
||||
kernel_size (`int`, *optional*, defaults to 11):
|
||||
Gaussian kernel size (K) for modulation.
|
||||
threshold (`float`, *optional*, defaults to None):
|
||||
Score threshold to keep object detection predictions.
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
`(height, width)` of each image in the batch. If unset, predictions will be resize with the default value.
|
||||
Returns:
|
||||
`List[List[Dict]]`: A list of dictionaries, each dictionary containing the keypoints and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
|
||||
# First compute centers and scales for each bounding box
|
||||
batch_size, num_keypoints, _, _ = outputs.heatmaps.shape
|
||||
|
||||
if target_sizes is not None:
|
||||
if batch_size != len(target_sizes):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||
)
|
||||
|
||||
centers = np.zeros((batch_size, 2), dtype=np.float32)
|
||||
scales = np.zeros((batch_size, 2), dtype=np.float32)
|
||||
flattened_boxes = list(itertools.chain(*boxes))
|
||||
for i in range(batch_size):
|
||||
if target_sizes is not None:
|
||||
image_width, image_height = target_sizes[i][0], target_sizes[i][1]
|
||||
scale_factor = np.array([image_width, image_height, image_width, image_height])
|
||||
flattened_boxes[i] = flattened_boxes[i] * scale_factor
|
||||
width, height = self.size["width"], self.size["height"]
|
||||
center, scale = box_to_center_and_scale(flattened_boxes[i], image_width=width, image_height=height)
|
||||
centers[i, :] = center
|
||||
scales[i, :] = scale
|
||||
|
||||
preds, scores = self.keypoints_from_heatmaps(
|
||||
outputs.heatmaps.cpu().numpy(), centers, scales, kernel=kernel_size
|
||||
)
|
||||
|
||||
all_boxes = np.zeros((batch_size, 4), dtype=np.float32)
|
||||
all_boxes[:, 0:2] = centers[:, 0:2]
|
||||
all_boxes[:, 2:4] = scales[:, 0:2]
|
||||
|
||||
poses = torch.tensor(preds)
|
||||
scores = torch.tensor(scores)
|
||||
labels = torch.arange(0, num_keypoints)
|
||||
bboxes_xyxy = torch.tensor(coco_to_pascal_voc(all_boxes))
|
||||
|
||||
results: List[List[Dict[str, torch.Tensor]]] = []
|
||||
|
||||
pose_bbox_pairs = zip(poses, scores, bboxes_xyxy)
|
||||
|
||||
for image_bboxes in boxes:
|
||||
image_results: List[Dict[str, torch.Tensor]] = []
|
||||
for _ in image_bboxes:
|
||||
# Unpack the next pose and bbox_xyxy from the iterator
|
||||
pose, score, bbox_xyxy = next(pose_bbox_pairs)
|
||||
score = score.squeeze()
|
||||
keypoints_labels = labels
|
||||
if threshold is not None:
|
||||
keep = score > threshold
|
||||
pose = pose[keep]
|
||||
score = score[keep]
|
||||
keypoints_labels = keypoints_labels[keep]
|
||||
pose_result = {"keypoints": pose, "scores": score, "labels": keypoints_labels, "bbox": bbox_xyxy}
|
||||
image_results.append(pose_result)
|
||||
results.append(image_results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
__all__ = ["VitPoseImageProcessor"]
|
340
src/transformers/models/vitpose/modeling_vitpose.py
Normal file
340
src/transformers/models/vitpose/modeling_vitpose.py
Normal file
@ -0,0 +1,340 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch VitPose model."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.backbone_utils import load_backbone
|
||||
from .configuration_vitpose import VitPoseConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "VitPoseConfig"
|
||||
|
||||
|
||||
@dataclass
|
||||
class VitPoseEstimatorOutput(ModelOutput):
|
||||
"""
|
||||
Class for outputs of pose estimation models.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
|
||||
heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
|
||||
Heatmaps as predicted by the model.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
|
||||
(also called feature maps) of the model at the output of each stage.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
heatmaps: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
class VitPosePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = VitPoseConfig
|
||||
base_model_prefix = "vit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
||||
# `trunc_normal_cpu` not implemented in `half` issues
|
||||
module.weight.data = nn.init.trunc_normal_(
|
||||
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
||||
).to(module.weight.dtype)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
VITPOSE_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`VitPoseConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
VITPOSE_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`VitPoseImageProcessor`]. See
|
||||
[`VitPoseImageProcessor.__call__`] for details.
|
||||
|
||||
dataset_index (`torch.Tensor` of shape `(batch_size,)`):
|
||||
Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
|
||||
|
||||
This corresponds to the dataset index used during training, e.g. For the single dataset index 0 refers to the corresponding dataset. For the multiple datasets index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose).
|
||||
|
||||
flip_pairs (`torch.tensor`, *optional*):
|
||||
Whether to mirror pairs of keypoints (for example, left ear -- right ear).
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
|
||||
"""Flip the flipped heatmaps back to the original form.
|
||||
|
||||
Args:
|
||||
output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
|
||||
The output heatmaps obtained from the flipped images.
|
||||
flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
|
||||
Pairs of keypoints which are mirrored (for example, left ear -- right ear).
|
||||
target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
|
||||
Target type to use. Can be gaussian-heatmap or combined-target.
|
||||
gaussian-heatmap: Classification target with gaussian distribution.
|
||||
combined-target: The combination of classification target (response map) and regression target (offset map).
|
||||
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: heatmaps that flipped back to the original image
|
||||
"""
|
||||
if target_type not in ["gaussian-heatmap", "combined-target"]:
|
||||
raise ValueError("target_type should be gaussian-heatmap or combined-target")
|
||||
|
||||
if output_flipped.ndim != 4:
|
||||
raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")
|
||||
batch_size, num_keypoints, height, width = output_flipped.shape
|
||||
channels = 1
|
||||
if target_type == "combined-target":
|
||||
channels = 3
|
||||
output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
|
||||
output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
|
||||
output_flipped_back = output_flipped.clone()
|
||||
|
||||
# Swap left-right parts
|
||||
for left, right in flip_pairs.tolist():
|
||||
output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
|
||||
output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
|
||||
output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))
|
||||
# Flip horizontally
|
||||
output_flipped_back = output_flipped_back.flip(-1)
|
||||
return output_flipped_back
|
||||
|
||||
|
||||
class VitPoseSimpleDecoder(nn.Module):
|
||||
"""
|
||||
Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
|
||||
feature maps into heatmaps.
|
||||
"""
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.activation = nn.ReLU()
|
||||
self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
|
||||
self.conv = nn.Conv2d(
|
||||
config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# Transform input: ReLU + upsample
|
||||
hidden_state = self.activation(hidden_state)
|
||||
hidden_state = self.upsampling(hidden_state)
|
||||
heatmaps = self.conv(hidden_state)
|
||||
|
||||
if flip_pairs is not None:
|
||||
heatmaps = flip_back(heatmaps, flip_pairs)
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
class VitPoseClassicDecoder(nn.Module):
|
||||
"""
|
||||
Classic decoding head consisting of a 2 deconvolutional blocks, followed by a 1x1 convolution layer,
|
||||
turning the feature maps into heatmaps.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VitPoseConfig):
|
||||
super().__init__()
|
||||
|
||||
self.deconv1 = nn.ConvTranspose2d(
|
||||
config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
|
||||
)
|
||||
self.batchnorm1 = nn.BatchNorm2d(256)
|
||||
self.relu1 = nn.ReLU()
|
||||
|
||||
self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
|
||||
self.batchnorm2 = nn.BatchNorm2d(256)
|
||||
self.relu2 = nn.ReLU()
|
||||
|
||||
self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None):
|
||||
hidden_state = self.deconv1(hidden_state)
|
||||
hidden_state = self.batchnorm1(hidden_state)
|
||||
hidden_state = self.relu1(hidden_state)
|
||||
|
||||
hidden_state = self.deconv2(hidden_state)
|
||||
hidden_state = self.batchnorm2(hidden_state)
|
||||
hidden_state = self.relu2(hidden_state)
|
||||
|
||||
heatmaps = self.conv(hidden_state)
|
||||
|
||||
if flip_pairs is not None:
|
||||
heatmaps = flip_back(heatmaps, flip_pairs)
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The VitPose model with a pose estimation head on top.",
|
||||
VITPOSE_START_DOCSTRING,
|
||||
)
|
||||
class VitPoseForPoseEstimation(VitPosePreTrainedModel):
|
||||
def __init__(self, config: VitPoseConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.backbone = load_backbone(config)
|
||||
|
||||
# add backbone attributes
|
||||
if not hasattr(self.backbone.config, "hidden_size"):
|
||||
raise ValueError("The backbone should have a hidden_size attribute")
|
||||
if not hasattr(self.backbone.config, "image_size"):
|
||||
raise ValueError("The backbone should have an image_size attribute")
|
||||
if not hasattr(self.backbone.config, "patch_size"):
|
||||
raise ValueError("The backbone should have a patch_size attribute")
|
||||
|
||||
self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(VITPOSE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=VitPoseEstimatorOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
dataset_index: Optional[torch.Tensor] = None,
|
||||
flip_pairs: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, VitPoseEstimatorOutput]:
|
||||
"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
|
||||
>>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
|
||||
>>> inputs = processor(image, boxes=boxes, return_tensors="pt")
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
>>> heatmaps = outputs.heatmaps
|
||||
```"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
raise NotImplementedError("Training is not yet supported")
|
||||
|
||||
outputs = self.backbone.forward_with_filtered_kwargs(
|
||||
pixel_values,
|
||||
dataset_index=dataset_index,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# Turn output hidden states in tensor of shape (batch_size, num_channels, height, width)
|
||||
sequence_output = outputs.feature_maps[-1] if return_dict else outputs[0][-1]
|
||||
batch_size = sequence_output.shape[0]
|
||||
patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0]
|
||||
patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1]
|
||||
sequence_output = (
|
||||
sequence_output.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width).contiguous()
|
||||
)
|
||||
|
||||
heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)
|
||||
|
||||
if not return_dict:
|
||||
if output_hidden_states:
|
||||
output = (heatmaps,) + outputs[1:]
|
||||
else:
|
||||
output = (heatmaps,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return VitPoseEstimatorOutput(
|
||||
loss=loss,
|
||||
heatmaps=heatmaps,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]
|
54
src/transformers/models/vitpose_backbone/__init__.py
Normal file
54
src/transformers/models/vitpose_backbone/__init__.py
Normal file
@ -0,0 +1,54 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {"configuration_vitpose_backbone": ["VitPoseBackboneConfig"]}
|
||||
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vitpose_backbone"] = [
|
||||
"VitPoseBackbonePreTrainedModel",
|
||||
"VitPoseBackbone",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vitpose_backbone import VitPoseBackboneConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_vitpose_backbone import (
|
||||
VitPoseBackbone,
|
||||
VitPoseBackbonePreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
@ -0,0 +1,136 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""VitPose backbone configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VitPoseBackboneConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`VitPoseBackbone`]. It is used to instantiate a
|
||||
VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the VitPose
|
||||
[usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
image_size (`int`, *optional*, defaults to `[256, 192]`):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
mlp_ratio (`int`, *optional*, defaults to 4):
|
||||
The ratio of the hidden size in the feedforward network to the hidden size in the attention layers.
|
||||
num_experts (`int`, *optional*, defaults to 1):
|
||||
The number of experts in the MoE layer.
|
||||
part_features (`int`, *optional*):
|
||||
The number of part features to output. Only used in case `num_experts` is greater than 1.
|
||||
hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
|
||||
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
|
||||
same order as defined in the `stage_names` attribute.
|
||||
out_indices (`List[int]`, *optional*):
|
||||
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
|
||||
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
|
||||
If unset and `out_features` is unset, will default to the last stage. Must be in the
|
||||
same order as defined in the `stage_names` attribute.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
|
||||
|
||||
>>> # Initializing a VitPose configuration
|
||||
>>> configuration = VitPoseBackboneConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the configuration
|
||||
>>> model = VitPoseBackbone(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "vitpose_backbone"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size=[256, 192],
|
||||
patch_size=[16, 16],
|
||||
num_channels=3,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
mlp_ratio=4,
|
||||
num_experts=1,
|
||||
part_features=256,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
qkv_bias=True,
|
||||
out_features=None,
|
||||
out_indices=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.num_experts = num_experts
|
||||
self.part_features = part_features
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.qkv_bias = qkv_bias
|
||||
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
|
||||
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
|
||||
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
|
||||
)
|
@ -0,0 +1,542 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch VitPose backbone model.
|
||||
|
||||
This code is the same as the original Vision Transformer (ViT) with 2 modifications:
|
||||
- use of padding=2 in the patch embedding layer
|
||||
- addition of a mixture-of-experts MLP layer
|
||||
"""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BackboneOutput, BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_vitpose_backbone import VitPoseBackboneConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "VitPoseBackboneConfig"
|
||||
|
||||
|
||||
class VitPoseBackbonePatchEmbeddings(nn.Module):
|
||||
"""Image to Patch Embedding."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
image_size = config.image_size
|
||||
patch_size = config.patch_size
|
||||
num_channels = config.num_channels
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=2)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
height, width = pixel_values.shape[-2:]
|
||||
if height != self.image_size[0] or width != self.image_size[1]:
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
embeddings = self.projection(pixel_values)
|
||||
|
||||
embeddings = embeddings.flatten(2).transpose(1, 2)
|
||||
return embeddings
|
||||
|
||||
|
||||
class VitPoseBackboneEmbeddings(nn.Module):
|
||||
"""
|
||||
Construct the position and patch embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_embeddings = VitPoseBackbonePatchEmbeddings(config)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
|
||||
# add positional encoding to each token
|
||||
embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
|
||||
class VitPoseBackboneSelfAttention(nn.Module):
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VitPoseBackbone
|
||||
class VitPoseBackboneSelfOutput(nn.Module):
|
||||
"""
|
||||
The residual connection is defined in VitPoseBackboneLayer instead of here (as is the case with other models), due to the
|
||||
layernorm applied before each block.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VitPoseBackbone
|
||||
class VitPoseBackboneAttention(nn.Module):
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
super().__init__()
|
||||
self.attention = VitPoseBackboneSelfAttention(config)
|
||||
self.output = VitPoseBackboneSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads: Set[int]) -> None:
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(
|
||||
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
||||
)
|
||||
|
||||
# Prune linear layers
|
||||
self.attention.query = prune_linear_layer(self.attention.query, index)
|
||||
self.attention.key = prune_linear_layer(self.attention.key, index)
|
||||
self.attention.value = prune_linear_layer(self.attention.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
|
||||
# Update hyper params and store pruned heads
|
||||
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
||||
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
||||
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
class VitPoseBackboneMoeMLP(nn.Module):
|
||||
def __init__(self, config: VitPoseBackboneConfig):
|
||||
super().__init__()
|
||||
|
||||
in_features = out_features = config.hidden_size
|
||||
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
||||
|
||||
num_experts = config.num_experts
|
||||
part_features = config.part_features
|
||||
|
||||
self.part_features = part_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = ACT2FN[config.hidden_act]
|
||||
self.fc2 = nn.Linear(hidden_features, out_features - part_features)
|
||||
self.drop = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
self.num_experts = num_experts
|
||||
experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
|
||||
self.experts = nn.ModuleList(experts)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :])
|
||||
|
||||
hidden_state = self.fc1(hidden_state)
|
||||
hidden_state = self.act(hidden_state)
|
||||
shared_hidden_state = self.fc2(hidden_state)
|
||||
indices = indices.view(-1, 1, 1)
|
||||
|
||||
# to support ddp training
|
||||
for i in range(self.num_experts):
|
||||
selected_index = indices == i
|
||||
current_hidden_state = self.experts[i](hidden_state) * selected_index
|
||||
expert_hidden_state = expert_hidden_state + current_hidden_state
|
||||
|
||||
hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1)
|
||||
|
||||
return hidden_state
|
||||
|
||||
|
||||
class VitPoseBackboneMLP(nn.Module):
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
super().__init__()
|
||||
in_features = out_features = config.hidden_size
|
||||
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.fc1(hidden_state)
|
||||
hidden_state = self.activation(hidden_state)
|
||||
hidden_state = self.fc2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class VitPoseBackboneLayer(nn.Module):
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.attention = VitPoseBackboneAttention(config)
|
||||
self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
dataset_index: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
# Validate dataset_index when using multiple experts
|
||||
if self.num_experts > 1 and dataset_index is None:
|
||||
raise ValueError(
|
||||
"dataset_index must be provided when using multiple experts "
|
||||
f"(num_experts={self.num_experts}). Please provide dataset_index "
|
||||
"to the forward pass."
|
||||
)
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in VitPoseBackbone, layernorm is applied before self-attention
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
# first residual connection
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
if self.num_experts == 1:
|
||||
layer_output = self.mlp(layer_output)
|
||||
else:
|
||||
layer_output = self.mlp(layer_output, indices=dataset_index)
|
||||
|
||||
# second residual connection
|
||||
layer_output = layer_output + hidden_states
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VitPoseBackbone
|
||||
class VitPoseBackboneEncoder(nn.Module):
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([VitPoseBackboneLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
dataset_index: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
dataset_index,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
class VitPoseBackbonePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = VitPoseBackboneConfig
|
||||
base_model_prefix = "vit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
||||
# `trunc_normal_cpu` not implemented in `half` issues
|
||||
module.weight.data = nn.init.trunc_normal_(
|
||||
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
||||
).to(module.weight.dtype)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, VitPoseBackboneEmbeddings):
|
||||
module.position_embeddings.data = nn.init.trunc_normal_(
|
||||
module.position_embeddings.data.to(torch.float32),
|
||||
mean=0.0,
|
||||
std=self.config.initializer_range,
|
||||
).to(module.position_embeddings.dtype)
|
||||
|
||||
|
||||
VITPOSE_BACKBONE_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`VitPoseBackboneConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
VITPOSE_BACKBONE_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values.
|
||||
|
||||
dataset_index (`torch.Tensor` of shape `(batch_size,)`):
|
||||
Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
|
||||
|
||||
This corresponds to the dataset index used during training, e.g. index 0 refers to COCO.
|
||||
|
||||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The VitPose backbone useful for downstream tasks.",
|
||||
VITPOSE_BACKBONE_START_DOCSTRING,
|
||||
)
|
||||
class VitPoseBackbone(VitPoseBackbonePreTrainedModel, BackboneMixin):
|
||||
def __init__(self, config: VitPoseBackboneConfig):
|
||||
super().__init__(config)
|
||||
super()._init_backbone(config)
|
||||
|
||||
self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
|
||||
self.embeddings = VitPoseBackboneEmbeddings(config)
|
||||
self.encoder = VitPoseBackboneEncoder(config)
|
||||
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(VITPOSE_BACKBONE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
dataset_index: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
|
||||
>>> import torch
|
||||
|
||||
>>> config = VitPoseBackboneConfig(out_indices=[-1])
|
||||
>>> model = VitPoseBackbone(config)
|
||||
|
||||
>>> pixel_values = torch.randn(1, 3, 256, 192)
|
||||
>>> dataset_index = torch.tensor([1])
|
||||
>>> outputs = model(pixel_values, dataset_index)
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(pixel_values)
|
||||
|
||||
outputs = self.encoder(
|
||||
embedding_output,
|
||||
dataset_index=dataset_index,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
||||
|
||||
feature_maps = ()
|
||||
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
||||
if stage in self.out_features:
|
||||
hidden_state = self.layernorm(hidden_state)
|
||||
feature_maps += (hidden_state,)
|
||||
|
||||
if not return_dict:
|
||||
if output_hidden_states:
|
||||
output = (feature_maps,) + outputs[1:]
|
||||
else:
|
||||
output = (feature_maps,) + outputs[2:]
|
||||
return output
|
||||
|
||||
return BackboneOutput(
|
||||
feature_maps=feature_maps,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
@ -9753,6 +9753,34 @@ class VitMattePreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VitPoseForPoseEstimation(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VitPosePreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VitPoseBackbone(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VitPoseBackbonePreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VitsModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -695,6 +695,13 @@ class VitMatteImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class VitPoseImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class VivitImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
0
tests/models/vitpose/__init__.py
Normal file
0
tests/models/vitpose/__init__.py
Normal file
229
tests/models/vitpose/test_image_processing_vitpose.py
Normal file
229
tests/models/vitpose/test_image_processing_vitpose.py
Normal file
@ -0,0 +1,229 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import VitPoseImageProcessor
|
||||
|
||||
|
||||
class VitPoseImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_affine_transform=True,
|
||||
size=None,
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
):
|
||||
size = size if size is not None else {"height": 20, "width": 20}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_affine_transform = do_affine_transform
|
||||
self.size = size
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_affine_transform": self.do_affine_transform,
|
||||
"size": self.size,
|
||||
"do_rescale": self.do_rescale,
|
||||
"rescale_factor": self.rescale_factor,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class VitPoseImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = VitPoseImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = VitPoseImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_affine_transform"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(
|
||||
self.image_processor_dict, size={"height": 42, "width": 42}
|
||||
)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
|
||||
encoded_images = image_processing(image_inputs[0], boxes=boxes, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
|
||||
encoded_images = image_processing(image_inputs, boxes=boxes, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size * 2, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
|
||||
encoded_images = image_processing(image_inputs[0], boxes=boxes, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
|
||||
encoded_images = image_processing(image_inputs, boxes=boxes, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size * 2, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
|
||||
encoded_images = image_processing(image_inputs[0], boxes=boxes, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
|
||||
encoded_images = image_processing(image_inputs, boxes=boxes, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size * 2, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Test that can process images which have an arbitrary number of channels
|
||||
# Initialize image_processing
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
# Test not batched input
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]]
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0],
|
||||
boxes=boxes,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (len(boxes[0]), *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
boxes = [[[0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]]] * self.image_processor_tester.batch_size
|
||||
encoded_images = image_processor(
|
||||
image_inputs,
|
||||
boxes=boxes,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape),
|
||||
(self.image_processor_tester.batch_size * len(boxes[0]), *expected_output_image_shape),
|
||||
)
|
332
tests/models/vitpose/test_modeling_vitpose.py
Normal file
332
tests/models/vitpose/test_modeling_vitpose.py
Normal file
@ -0,0 +1,332 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch VitPose model."""
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import VitPoseBackboneConfig, VitPoseConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import VitPoseForPoseEstimation
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import VitPoseImageProcessor
|
||||
|
||||
|
||||
class VitPoseModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=[16 * 8, 12 * 8],
|
||||
patch_size=[8, 8],
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=2,
|
||||
scale_factor=4,
|
||||
out_indices=[-1],
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.scale_factor = scale_factor
|
||||
self.out_indices = out_indices
|
||||
self.scope = scope
|
||||
|
||||
# in VitPose, the seq length equals the number of patches
|
||||
num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])
|
||||
self.seq_length = num_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return VitPoseConfig(
|
||||
backbone_config=self.get_backbone_config(),
|
||||
)
|
||||
|
||||
def get_backbone_config(self):
|
||||
return VitPoseBackboneConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
hidden_act=self.hidden_act,
|
||||
out_indices=self.out_indices,
|
||||
)
|
||||
|
||||
def create_and_check_for_pose_estimation(self, config, pixel_values, labels):
|
||||
model = VitPoseForPoseEstimation(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
expected_height = (self.image_size[0] // self.patch_size[0]) * self.scale_factor
|
||||
expected_width = (self.image_size[1] // self.patch_size[1]) * self.scale_factor
|
||||
|
||||
self.parent.assertEqual(
|
||||
result.heatmaps.shape, (self.batch_size, self.num_labels, expected_height, expected_width)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
pixel_values,
|
||||
labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class VitPoseModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as VitPose does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (VitPoseForPoseEstimation,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = VitPoseModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=VitPoseConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.create_and_test_config_to_json_string()
|
||||
self.config_tester.create_and_test_config_to_json_file()
|
||||
self.config_tester.create_and_test_config_from_and_save_pretrained()
|
||||
self.config_tester.create_and_test_config_with_num_labels()
|
||||
self.config_tester.check_config_can_be_init_without_params()
|
||||
self.config_tester.check_config_arguments_init()
|
||||
|
||||
@unittest.skip(reason="VitPose does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPose does not support input and output embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPose does not support input and output embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPose does not support training yet")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPose does not support training yet")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPose does not support training yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPose does not support training yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_for_pose_estimation(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "usyd-community/vitpose-base-simple"
|
||||
model = VitPoseForPoseEstimation.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
# We will verify our results on an image of people in house
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000000139.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
return image
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class VitPoseModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_image_processor(self):
|
||||
return (
|
||||
VitPoseImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
|
||||
if is_vision_available()
|
||||
else None
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_inference_pose_estimation(self):
|
||||
image_processor = self.default_image_processor
|
||||
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
image = prepare_img()
|
||||
boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
|
||||
|
||||
inputs = image_processor(images=image, boxes=boxes, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
heatmaps = outputs.heatmaps
|
||||
|
||||
assert heatmaps.shape == (2, 17, 64, 48)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[9.9330e-06, 9.9330e-06, 9.9330e-06],
|
||||
[9.9330e-06, 9.9330e-06, 9.9330e-06],
|
||||
[9.9330e-06, 9.9330e-06, 9.9330e-06],
|
||||
]
|
||||
).to(torch_device)
|
||||
|
||||
assert torch.allclose(heatmaps[0, 0, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
|
||||
|
||||
expected_bbox = torch.tensor([391.9900, 190.0800, 391.1575, 189.3034])
|
||||
expected_keypoints = torch.tensor(
|
||||
[
|
||||
[3.9813e02, 1.8184e02],
|
||||
[3.9828e02, 1.7981e02],
|
||||
[3.9596e02, 1.7948e02],
|
||||
]
|
||||
)
|
||||
expected_scores = torch.tensor([8.7529e-01, 8.4315e-01, 9.2678e-01])
|
||||
|
||||
self.assertEqual(len(pose_results), 2)
|
||||
self.assertTrue(torch.allclose(pose_results[1]["bbox"].cpu(), expected_bbox, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(pose_results[1]["keypoints"][:3].cpu(), expected_keypoints, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(pose_results[1]["scores"][:3].cpu(), expected_scores, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_batched_inference(self):
|
||||
image_processor = self.default_image_processor
|
||||
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
image = prepare_img()
|
||||
boxes = [
|
||||
[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]],
|
||||
[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]],
|
||||
]
|
||||
|
||||
inputs = image_processor(images=[image, image], boxes=boxes, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
heatmaps = outputs.heatmaps
|
||||
|
||||
assert heatmaps.shape == (4, 17, 64, 48)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[9.9330e-06, 9.9330e-06, 9.9330e-06],
|
||||
[9.9330e-06, 9.9330e-06, 9.9330e-06],
|
||||
[9.9330e-06, 9.9330e-06, 9.9330e-06],
|
||||
]
|
||||
).to(torch_device)
|
||||
|
||||
assert torch.allclose(heatmaps[0, 0, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)
|
||||
print(pose_results)
|
||||
|
||||
expected_bbox = torch.tensor([391.9900, 190.0800, 391.1575, 189.3034])
|
||||
expected_keypoints = torch.tensor(
|
||||
[
|
||||
[3.9813e02, 1.8184e02],
|
||||
[3.9828e02, 1.7981e02],
|
||||
[3.9596e02, 1.7948e02],
|
||||
]
|
||||
)
|
||||
expected_scores = torch.tensor([8.7529e-01, 8.4315e-01, 9.2678e-01])
|
||||
|
||||
self.assertEqual(len(pose_results), 2)
|
||||
self.assertEqual(len(pose_results[0]), 2)
|
||||
self.assertTrue(torch.allclose(pose_results[0][1]["bbox"].cpu(), expected_bbox, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(pose_results[0][1]["keypoints"][:3].cpu(), expected_keypoints, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(pose_results[0][1]["scores"][:3].cpu(), expected_scores, atol=1e-4))
|
0
tests/models/vitpose_backbone/__init__.py
Normal file
0
tests/models/vitpose_backbone/__init__.py
Normal file
199
tests/models/vitpose_backbone/test_modeling_vitpose_backbone.py
Normal file
199
tests/models/vitpose_backbone/test_modeling_vitpose_backbone.py
Normal file
@ -0,0 +1,199 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch VitPose backbone model."""
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import VitPoseBackboneConfig
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_backbone_common import BackboneTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import VitPoseBackbone
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
pass
|
||||
|
||||
|
||||
class VitPoseBackboneModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=[16 * 8, 12 * 8],
|
||||
patch_size=[8, 8],
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=2,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.scope = scope
|
||||
|
||||
# in VitPoseBackbone, the seq length equals the number of patches
|
||||
num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])
|
||||
self.seq_length = num_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return VitPoseBackboneConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
initializer_range=self.initializer_range,
|
||||
num_labels=self.num_labels,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
pixel_values,
|
||||
labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as VitPoseBackbone does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (VitPoseBackbone,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = VitPoseBackboneModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=VitPoseBackboneConfig, has_text_modality=False, hidden_size=37
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support input and output embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support feedforward chunking")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not output a loss")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support training yet")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support training yet")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support training yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="VitPoseBackbone does not support training yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
|
||||
@require_torch
|
||||
class VitPoseBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
all_model_classes = (VitPoseBackbone,) if is_torch_available() else ()
|
||||
config_class = VitPoseBackboneConfig
|
||||
|
||||
has_attentions = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = VitPoseBackboneModelTester(self)
|
@ -1092,6 +1092,7 @@ MODELS_NOT_IN_README = [
|
||||
"CLIPVisionModel",
|
||||
"SiglipVisionModel",
|
||||
"ChineseCLIPVisionModel",
|
||||
"VitPoseBackbone",
|
||||
]
|
||||
|
||||
# Template for new entries to add in the main README when we have missing models.
|
||||
|
@ -330,6 +330,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"SiglipVisionModel",
|
||||
"SiglipTextModel",
|
||||
"ChameleonVQVAE", # no autoclass for VQ-VAE models
|
||||
"VitPoseForPoseEstimation",
|
||||
"CLIPTextModel",
|
||||
"MoshiForConditionalGeneration", # no auto class for speech-to-speech
|
||||
]
|
||||
@ -993,6 +994,8 @@ UNDOCUMENTED_OBJECTS = [
|
||||
"logging", # External module
|
||||
"requires_backends", # Internal function
|
||||
"AltRobertaModel", # Internal module
|
||||
"VitPoseBackbone", # Internal module
|
||||
"VitPoseBackboneConfig", # Internal module
|
||||
]
|
||||
|
||||
# This list should be empty. Objects in it should get their own doc page.
|
||||
|
Loading…
Reference in New Issue
Block a user