specific language governing permissions and limitations under the License.
-->

# ViTPose

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

## Overview

[ViTPose](https://huggingface.co/papers/2204.12484) is a vision transformer-based model for keypoint (pose) estimation, proposed in [ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation](https://huggingface.co/papers/2204.12484) by Yufei Xu, Jing Zhang, Qiming Zhang, and Dacheng Tao. It uses a plain, non-hierarchical [ViT](./vit) backbone and a lightweight decoder head that predicts heatmaps from a given image. This architecture simplifies model design, takes advantage of transformer scalability, and can be adapted to different training strategies. Despite its simplicity, ViTPose achieves state-of-the-art results on the challenging MS COCO Keypoint Detection benchmark.

[ViTPose++](https://huggingface.co/papers/2212.04246) improves on ViTPose by incorporating a mixture-of-experts (MoE) module in the backbone and pretraining on more diverse data, which further enhances performance.

The abstract from the paper is the following:

*Although no specific domain knowledge is considered in the design, plain vision transformers have shown excellent performance in visual recognition tasks. However, little effort has been made to reveal the potential of such simple structures for pose estimation tasks. In this paper, we show the surprisingly good capabilities of plain vision transformers for pose estimation from various aspects, namely simplicity in model structure, scalability in model size, flexibility in training paradigm, and transferability of knowledge between models, through a simple baseline model called ViTPose. Specifically, ViTPose employs plain and non-hierarchical vision transformers as backbones to extract features for a given person instance and a lightweight decoder for pose estimation. It can be scaled up from 100M to 1B parameters by taking the advantages of the scalable model capacity and high parallelism of transformers, setting a new Pareto front between throughput and performance. Besides, ViTPose is very flexible regarding the attention type, input resolution, pre-training and finetuning strategy, as well as dealing with multiple pose tasks. We also empirically demonstrate that the knowledge of large ViTPose models can be easily transferred to small ones via a simple knowledge token. Experimental results show that our basic ViTPose model outperforms representative methods on the challenging MS COCO Keypoint Detection benchmark, while the largest model sets a new state-of-the-art.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-architecture.png" alt="drawing" width="600"/>

<small> ViTPose architecture. Taken from the <a href="https://huggingface.co/papers/2204.12484">original paper.</a> </small>

You can find all ViTPose and ViTPose++ checkpoints under the [ViTPose collection](https://huggingface.co/collections/usyd-community/vitpose-677fcfd0a0b2b5c8f79c4335).

This model was contributed by [nielsr](https://huggingface.co/nielsr) and [sangbumchoi](https://github.com/SangbumChoi).
The original code can be found [here](https://github.com/ViTAE-Transformer/ViTPose).

## Usage Tips

ViTPose is a top-down keypoint detection model: an object detector, such as [RT-DETR](./rt_detr), first detects people (or other instances) in an image, and ViTPose then takes the cropped images as input and predicts the keypoints for each of them.

The example below demonstrates pose estimation with the [`VitPoseForPoseEstimation`] class.

```py
import torch
import requests
import numpy as np
import supervision as sv
from PIL import Image

from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

device = "cuda" if torch.cuda.is_available() else "cpu"

url = "https://www.fcbarcelona.com/fcbarcelona/photo/2021/01/31/3c55a19f-dfc1-4451-885e-afd14e890a11/mini_2021-01-31-BARCELONA-ATHLETIC-BILBAOI-30.JPG"
image = Image.open(requests.get(url, stream=True).raw)

# Stage 1. Detect humans in the image (any person detector of your choice works here)
person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)

inputs = person_image_processor(images=image, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = person_model(**inputs)

results = person_image_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
)
result = results[0]  # take first image results

# The human label corresponds to index 0 in the COCO dataset
person_boxes = result["boxes"][result["labels"] == 0]
person_boxes = person_boxes.cpu().numpy()

# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

# Stage 2. Detect keypoints for each person found
image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)

inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
image_pose_result = pose_results[0]  # results for first image
```
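
Each entry in `image_pose_result` corresponds to one detected person and contains `keypoints` and `scores` tensors. The short sketch below is one way to inspect them; the loop and print formatting are illustrative and not part of the original example.

```py
# print the predicted keypoints and confidence scores for each detected person
for person_idx, person_pose in enumerate(image_pose_result):
    keypoints = person_pose["keypoints"]  # tensor of shape (num_keypoints, 2) with (x, y) coordinates
    scores = person_pose["scores"]        # tensor of shape (num_keypoints,) with confidence scores
    print(f"Person {person_idx}: {len(keypoints)} keypoints")
    for (x, y), score in zip(keypoints.tolist(), scores.tolist()):
        print(f"  ({x:.1f}, {y:.1f})  score={score:.2f}")
```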

### Visualization

To visualize the predicted keypoints, you can use the [supervision](https://github.com/roboflow/supervision) library (requires `pip install supervision`):

```python
import supervision as sv

xy = torch.stack([pose_result["keypoints"] for pose_result in image_pose_result]).cpu().numpy()
scores = torch.stack([pose_result["scores"] for pose_result in image_pose_result]).cpu().numpy()

# build a KeyPoints object and annotate skeleton edges and keypoint vertices
key_points = sv.KeyPoints(xy=xy, confidence=scores)

edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=1)
vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=2)

annotated_frame = edge_annotator.annotate(
    scene=image.copy(),
    key_points=key_points
)
annotated_frame = vertex_annotator.annotate(
    scene=annotated_frame,
    key_points=key_points
)
annotated_frame
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose.png"/>
</div>

Alternatively, you can visualize the keypoints with [OpenCV](https://opencv.org/) (requires `pip install opencv-python`), as shown in the notes below.

### Quantization

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.

```py
# pip install torchao
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation, TorchAoConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

url = "https://www.fcbarcelona.com/fcbarcelona/photo/2021/01/31/3c55a19f-dfc1-4451-885e-afd14e890a11/mini_2021-01-31-BARCELONA-ATHLETIC-BILBAOI-30.JPG"
image = Image.open(requests.get(url, stream=True).raw)

# Stage 1. Detect humans in the image
person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)

inputs = person_image_processor(images=image, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = person_model(**inputs)

results = person_image_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
)
result = results[0]

# The human label corresponds to index 0 in the COCO dataset
person_boxes = result["boxes"][result["labels"] == 0]
person_boxes = person_boxes.cpu().numpy()

# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

# Stage 2. Detect keypoints with an int4 weight-only quantized model
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)

image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-plus-huge")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-huge", device_map=device, quantization_config=quantization_config)

inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
image_pose_result = pose_results[0]
```
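
To get a rough sense of the memory savings, you can compare the quantized model's footprint against an unquantized copy. This is a minimal sketch, not part of the original example; it loads a second, full-precision copy of the checkpoint and therefore assumes enough free memory.

```py
# compare memory footprints; `model` is the int4-quantized model from the example above
full_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-huge", device_map=device)
print(f"unquantized footprint: {full_model.get_memory_footprint() / 1e9:.2f} GB")
print(f"int4 footprint:        {model.get_memory_footprint() / 1e9:.2f} GB")
```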

## Notes

- Use [`AutoProcessor`] to automatically prepare bounding box and image inputs.
- ViTPose is a top-down pose estimator. It uses an object detector to detect individuals first before keypoint prediction.
- The best [checkpoints](https://huggingface.co/collections/usyd-community/vitpose-677fcfd0a0b2b5c8f79c4335) are those of the [ViTPose++ paper](https://huggingface.co/papers/2212.04246). ViTPose++ models employ a [Mixture-of-Experts (MoE)](https://huggingface.co/blog/moe) backbone with 6 expert heads, one per pretraining dataset:

    - 0: [COCO validation 2017](https://cocodataset.org/#overview) dataset, using an object detector that gets 56 AP on the "person" class
    - 1: [AiC](https://github.com/fabbrimatteo/AiC-Dataset) dataset
    - 2: [MPII](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/software-and-datasets/mpii-human-pose-dataset) dataset
    - 3: [AP-10K](https://github.com/AlexTheBad/AP-10K) dataset
    - 4: [APT-36K](https://github.com/pandorgan/APT-36K) dataset
    - 5: [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody) dataset

    Pass the `dataset_index` argument in the forward pass to indicate which expert to use for each example in the batch:

    ```py
    import torch
    from transformers import AutoProcessor, VitPoseForPoseEstimation

    device = "cuda" if torch.cuda.is_available() else "cpu"

    image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-plus-base")
    model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-base", device_map=device)

    # `image` and `person_boxes` come from the person detection step shown above
    inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
    dataset_index = torch.tensor([0], device=device)  # must be a tensor of shape (batch_size,)

    with torch.no_grad():
        outputs = model(**inputs, dataset_index=dataset_index)
    ```
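
    As a usage sketch (the batch here is hypothetical, not from the original docs), a batch mixing crops from different datasets would pass one index per example:

    ```py
    # hypothetical: score the first crop with the COCO expert (0) and the second with the AP-10K expert (3)
    dataset_index = torch.tensor([0, 3], device=device)
    ```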

- [OpenCV](https://opencv.org/) is an alternative option for visualizing the estimated pose.

    ```py
    # pip install opencv-python
    import math

    import cv2
    import numpy as np
    from PIL import Image


    def draw_points(image, keypoints, scores, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
        if pose_keypoint_color is not None:
            assert len(pose_keypoint_color) == len(keypoints)
        for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
            x_coord, y_coord = int(kpt[0]), int(kpt[1])
            if kpt_score > keypoint_score_threshold:
                color = tuple(int(c) for c in pose_keypoint_color[kid])
                if show_keypoint_weight:
                    cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
                    transparency = max(0, min(1, kpt_score))
                    cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
                else:
                    cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)


    def draw_links(image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width=2):
        height, width, _ = image.shape
        if keypoint_edges is not None and link_colors is not None:
            assert len(link_colors) == len(keypoint_edges)
            for sk_id, sk in enumerate(keypoint_edges):
                x1, y1, score1 = (int(keypoints[sk[0], 0]), int(keypoints[sk[0], 1]), scores[sk[0]])
                x2, y2, score2 = (int(keypoints[sk[1], 0]), int(keypoints[sk[1], 1]), scores[sk[1]])
                if (
                    x1 > 0
                    and x1 < width
                    and y1 > 0
                    and y1 < height
                    and x2 > 0
                    and x2 < width
                    and y2 > 0
                    and y2 < height
                    and score1 > keypoint_score_threshold
                    and score2 > keypoint_score_threshold
                ):
                    color = tuple(int(c) for c in link_colors[sk_id])
                    if show_keypoint_weight:
                        X = (x1, x2)
                        Y = (y1, y2)
                        mean_x = np.mean(X)
                        mean_y = np.mean(Y)
                        length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
                        angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
                        polygon = cv2.ellipse2Poly(
                            (int(mean_x), int(mean_y)), (int(length / 2), int(stick_width)), int(angle), 0, 360, 1
                        )
                        cv2.fillConvexPoly(image, polygon, color)
                        transparency = max(0, min(1, 0.5 * (keypoints[sk[0], 2] + keypoints[sk[1], 2])))
                        cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
                    else:
                        cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)


    # Note: keypoint_edges and color palette are dataset-specific
    keypoint_edges = model.config.edges

    palette = np.array(
        [
            [255, 128, 0],
            [255, 153, 51],
            [255, 178, 102],
            [230, 230, 0],
            [255, 153, 255],
            [153, 204, 255],
            [255, 102, 255],
            [255, 51, 255],
            [102, 178, 255],
            [51, 153, 255],
            [255, 153, 153],
            [255, 102, 102],
            [255, 51, 51],
            [153, 255, 153],
            [102, 255, 102],
            [51, 255, 51],
            [0, 255, 0],
            [0, 0, 255],
            [255, 0, 0],
            [255, 255, 255],
        ]
    )

    link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
    keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]

    # `image`, `image_pose_result`, and `model` come from the pose estimation example above
    numpy_image = np.array(image)

    for pose_result in image_pose_result:
        scores = np.array(pose_result["scores"])
        keypoints = np.array(pose_result["keypoints"])

        # draw each point on the image
        draw_points(numpy_image, keypoints, scores, keypoint_colors, keypoint_score_threshold=0.3, radius=4, show_keypoint_weight=False)

        # draw the links between points
        draw_links(numpy_image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold=0.3, thickness=1, show_keypoint_weight=False)

    pose_image = Image.fromarray(numpy_image)
    pose_image
    ```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-coco.jpg" alt="drawing" width="600"/>

## Resources

Refer to the resources below to learn more about using ViTPose.

- This [notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/ViTPose/Inference_with_ViTPose_for_body_pose_estimation.ipynb) demonstrates inference and visualization.
- This [Space](https://huggingface.co/spaces/hysts/ViTPose-transformers) demonstrates ViTPose on images and video.

## VitPoseImageProcessor