mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Add video classification pipeline (#20151)
* 🚧 wip video classification pipeline * 🚧 wip - add is_decord_available check * 🐛 add missing import * ✅ add tests * 🔧 add decord to setup extras * 🚧 add is_decord_available * ✨ add video-classification pipeline * 📝 add video classification pipe to docs * 🐛 add missing VideoClassificationPipeline import * 📌 add decord install in test runner * ✅ fix url inputs to video-classification pipeline * ✨ updates from review * 📝 add video cls pipeline to docs * 📝 add docstring * 🔥 remove unused import * 🔥 remove some code * 📝 docfix
This commit is contained in:
parent
c56ebbbea6
commit
9e56aff58a
@ -188,7 +188,7 @@ pipelines_torch_job = CircleCIJob(
|
||||
install_steps=[
|
||||
"sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng",
|
||||
"pip install --upgrade pip",
|
||||
"pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]",
|
||||
"pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm,video]",
|
||||
],
|
||||
pytest_options={"rA": None},
|
||||
tests_to_run="tests/pipelines/"
|
||||
|
@ -341,6 +341,12 @@ Pipelines available for computer vision tasks include the following.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### VideoClassificationPipeline
|
||||
|
||||
[[autodoc]] VideoClassificationPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ZeroShotImageClassificationPipeline
|
||||
|
||||
[[autodoc]] ZeroShotImageClassificationPipeline
|
||||
|
4
setup.py
4
setup.py
@ -103,6 +103,7 @@ _deps = [
|
||||
"cookiecutter==1.7.3",
|
||||
"dataclasses",
|
||||
"datasets!=2.5.0",
|
||||
"decord==0.6.0",
|
||||
"deepspeed>=0.6.5",
|
||||
"dill<0.3.5",
|
||||
"evaluate>=0.2.0",
|
||||
@ -286,7 +287,7 @@ extras["vision"] = deps_list("Pillow")
|
||||
extras["timm"] = deps_list("timm")
|
||||
extras["natten"] = deps_list("natten")
|
||||
extras["codecarbon"] = deps_list("codecarbon")
|
||||
|
||||
extras["video"] = deps_list("decord")
|
||||
|
||||
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
|
||||
extras["testing"] = (
|
||||
@ -332,6 +333,7 @@ extras["all"] = (
|
||||
+ extras["timm"]
|
||||
+ extras["codecarbon"]
|
||||
+ extras["accelerate"]
|
||||
+ extras["video"]
|
||||
)
|
||||
|
||||
# Might need to add doc-builder and some specific deps in the future
|
||||
|
@ -489,6 +489,7 @@ _import_structure = {
|
||||
"TextGenerationPipeline",
|
||||
"TokenClassificationPipeline",
|
||||
"TranslationPipeline",
|
||||
"VideoClassificationPipeline",
|
||||
"VisualQuestionAnsweringPipeline",
|
||||
"ZeroShotClassificationPipeline",
|
||||
"ZeroShotImageClassificationPipeline",
|
||||
@ -534,6 +535,7 @@ _import_structure = {
|
||||
"add_start_docstrings",
|
||||
"is_apex_available",
|
||||
"is_datasets_available",
|
||||
"is_decord_available",
|
||||
"is_faiss_available",
|
||||
"is_flax_available",
|
||||
"is_keras_nlp_available",
|
||||
@ -3724,6 +3726,7 @@ if TYPE_CHECKING:
|
||||
TextGenerationPipeline,
|
||||
TokenClassificationPipeline,
|
||||
TranslationPipeline,
|
||||
VideoClassificationPipeline,
|
||||
VisualQuestionAnsweringPipeline,
|
||||
ZeroShotClassificationPipeline,
|
||||
ZeroShotImageClassificationPipeline,
|
||||
@ -3774,6 +3777,7 @@ if TYPE_CHECKING:
|
||||
add_start_docstrings,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
is_decord_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_keras_nlp_available,
|
||||
|
@ -9,6 +9,7 @@ deps = {
|
||||
"cookiecutter": "cookiecutter==1.7.3",
|
||||
"dataclasses": "dataclasses",
|
||||
"datasets": "datasets!=2.5.0",
|
||||
"decord": "decord==0.6.0",
|
||||
"deepspeed": "deepspeed>=0.6.5",
|
||||
"dill": "dill<0.3.5",
|
||||
"evaluate": "evaluate>=0.2.0",
|
||||
|
@ -79,6 +79,7 @@ from .token_classification import (
|
||||
TokenClassificationArgumentHandler,
|
||||
TokenClassificationPipeline,
|
||||
)
|
||||
from .video_classification import VideoClassificationPipeline
|
||||
from .visual_question_answering import VisualQuestionAnsweringPipeline
|
||||
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
|
||||
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
|
||||
@ -133,6 +134,7 @@ if is_torch_available():
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForVideoClassification,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForVisualQuestionAnswering,
|
||||
AutoModelForZeroShotObjectDetection,
|
||||
@ -361,6 +363,13 @@ SUPPORTED_TASKS = {
|
||||
"default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
|
||||
"type": "image",
|
||||
},
|
||||
"video-classification": {
|
||||
"impl": VideoClassificationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
|
||||
"type": "video",
|
||||
},
|
||||
}
|
||||
|
||||
NO_FEATURE_EXTRACTOR_TASKS = set()
|
||||
@ -373,7 +382,7 @@ MULTI_MODEL_CONFIGS = {"SpeechEncoderDecoderConfig", "VisionEncoderDecoderConfig
|
||||
for task, values in SUPPORTED_TASKS.items():
|
||||
if values["type"] == "text":
|
||||
NO_FEATURE_EXTRACTOR_TASKS.add(task)
|
||||
elif values["type"] in {"audio", "image"}:
|
||||
elif values["type"] in {"audio", "image", "video"}:
|
||||
NO_TOKENIZER_TASKS.add(task)
|
||||
elif values["type"] != "multimodal":
|
||||
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
|
||||
|
124
src/transformers/pipelines/video_classification.py
Normal file
124
src/transformers/pipelines/video_classification.py
Normal file
@ -0,0 +1,124 @@
|
||||
from io import BytesIO
|
||||
from typing import List, Union
|
||||
|
||||
import requests
|
||||
|
||||
from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if is_decord_available():
|
||||
import numpy as np
|
||||
|
||||
from decord import VideoReader
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class VideoClassificationPipeline(Pipeline):
|
||||
"""
|
||||
Video classification pipeline using any `AutoModelForVideoClassification`. This pipeline predicts the class of a
|
||||
video.
|
||||
|
||||
This video classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
`"video-classification"`.
|
||||
|
||||
See the list of available models on
|
||||
[huggingface.co/models](https://huggingface.co/models?filter=video-classification).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
requires_backends(self, "decord")
|
||||
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING)
|
||||
|
||||
def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
|
||||
preprocess_params = {}
|
||||
if frame_sampling_rate is not None:
|
||||
preprocess_params["frame_sampling_rate"] = frame_sampling_rate
|
||||
if num_frames is not None:
|
||||
preprocess_params["num_frames"] = num_frames
|
||||
|
||||
postprocess_params = {}
|
||||
if top_k is not None:
|
||||
postprocess_params["top_k"] = top_k
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def __call__(self, videos: Union[str, List[str]], **kwargs):
|
||||
"""
|
||||
Assign labels to the video(s) passed as inputs.
|
||||
|
||||
Args:
|
||||
videos (`str`, `List[str]`):
|
||||
The pipeline handles three types of videos:
|
||||
|
||||
- A string containing a http link pointing to a video
|
||||
- A string containing a local path to a video
|
||||
|
||||
The pipeline accepts either a single video or a batch of videos, which must then be passed as a string.
|
||||
Videos in a batch must all be in the same format: all as http links or all as local paths.
|
||||
top_k (`int`, *optional*, defaults to 5):
|
||||
The number of top labels that will be returned by the pipeline. If the provided number is higher than
|
||||
the number of labels available in the model configuration, it will default to the number of labels.
|
||||
num_frames (`int`, *optional*, defaults to `self.model.config.num_frames`):
|
||||
The number of frames sampled from the video to run the classification on. If not provided, will default
|
||||
to the number of frames specified in the model configuration.
|
||||
frame_sampling_rate (`int`, *optional*, defaults to 1):
|
||||
The sampling rate used to select frames from the video. If not provided, will default to 1, i.e. every
|
||||
frame will be used.
|
||||
|
||||
Return:
|
||||
A dictionary or a list of dictionaries containing result. If the input is a single video, will return a
|
||||
dictionary, if the input is a list of several videos, will return a list of dictionaries corresponding to
|
||||
the videos.
|
||||
|
||||
The dictionaries contain the following keys:
|
||||
|
||||
- **label** (`str`) -- The label identified by the model.
|
||||
- **score** (`int`) -- The score attributed by the model for that label.
|
||||
"""
|
||||
return super().__call__(videos, **kwargs)
|
||||
|
||||
def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
|
||||
|
||||
if num_frames is None:
|
||||
num_frames = self.model.config.num_frames
|
||||
|
||||
if video.startswith("http://") or video.startswith("https://"):
|
||||
video = BytesIO(requests.get(video).content)
|
||||
|
||||
videoreader = VideoReader(video)
|
||||
videoreader.seek(0)
|
||||
|
||||
start_idx = 0
|
||||
end_idx = num_frames * frame_sampling_rate - 1
|
||||
indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)
|
||||
|
||||
video = videoreader.get_batch(indices).asnumpy()
|
||||
video = list(video)
|
||||
|
||||
model_inputs = self.feature_extractor(video, return_tensors=self.framework)
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
model_outputs = self.model(**model_inputs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs, top_k=5):
|
||||
if top_k > self.model.config.num_labels:
|
||||
top_k = self.model.config.num_labels
|
||||
|
||||
if self.framework == "pt":
|
||||
probs = model_outputs.logits.softmax(-1)[0]
|
||||
scores, ids = probs.topk(top_k)
|
||||
else:
|
||||
raise ValueError(f"Unsupported framework: {self.framework}")
|
||||
|
||||
scores = scores.tolist()
|
||||
ids = ids.tolist()
|
||||
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
@ -51,6 +51,7 @@ from .utils import (
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bs4_available,
|
||||
is_decord_available,
|
||||
is_detectron2_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
@ -446,6 +447,13 @@ def require_spacy(test_case):
|
||||
return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
|
||||
|
||||
|
||||
def require_decord(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires decord. These tests are skipped when decord isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
|
||||
|
@ -104,6 +104,7 @@ from .import_utils import (
|
||||
is_bs4_available,
|
||||
is_coloredlogs_available,
|
||||
is_datasets_available,
|
||||
is_decord_available,
|
||||
is_detectron2_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
|
@ -268,6 +268,13 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_is_ccl_available = False
|
||||
|
||||
_decord_availale = importlib.util.find_spec("decord") is not None
|
||||
try:
|
||||
_decord_version = importlib_metadata.version("decord")
|
||||
logger.debug(f"Successfully imported decord version {_decord_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_decord_availale = False
|
||||
|
||||
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
||||
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
||||
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
|
||||
@ -706,6 +713,10 @@ def is_ccl_available():
|
||||
return _is_ccl_available
|
||||
|
||||
|
||||
def is_decord_available():
|
||||
return _decord_availale
|
||||
|
||||
|
||||
def is_sudachi_available():
|
||||
return importlib.util.find_spec("sudachipy") is not None
|
||||
|
||||
@ -953,6 +964,11 @@ CCL_IMPORT_ERROR = """
|
||||
Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
DECORD_IMPORT_ERROR = """
|
||||
{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
|
||||
decord`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||
@ -982,6 +998,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
||||
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
|
||||
("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
|
||||
("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
96
tests/pipelines/test_pipelines_video_classification.py
Normal file
96
tests/pipelines/test_pipelines_video_classification.py
Normal file
@ -0,0 +1,96 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, VideoMAEFeatureExtractor
|
||||
from transformers.pipelines import VideoClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
nested_simplify,
|
||||
require_decord,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_or_tf,
|
||||
require_vision,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
@require_torch_or_tf
|
||||
@require_vision
|
||||
@require_decord
|
||||
class VideoClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||
example_video_filepath = hf_hub_download(
|
||||
repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset"
|
||||
)
|
||||
video_classifier = VideoClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
|
||||
examples = [
|
||||
example_video_filepath,
|
||||
"https://huggingface.co/datasets/nateraw/video-demo/resolve/main/archery.mp4",
|
||||
]
|
||||
return video_classifier, examples
|
||||
|
||||
def run_pipeline_test(self, video_classifier, examples):
|
||||
|
||||
for example in examples:
|
||||
outputs = video_classifier(example)
|
||||
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
small_model = "hf-internal-testing/tiny-random-VideoMAEForVideoClassification"
|
||||
small_feature_extractor = VideoMAEFeatureExtractor(
|
||||
size=dict(shortest_edge=10), crop_size=dict(height=10, width=10)
|
||||
)
|
||||
video_classifier = pipeline(
|
||||
"video-classification", model=small_model, feature_extractor=small_feature_extractor, frame_sampling_rate=4
|
||||
)
|
||||
|
||||
video_file_path = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")
|
||||
outputs = video_classifier(video_file_path, top_k=2)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
|
||||
)
|
||||
|
||||
outputs = video_classifier(
|
||||
[
|
||||
video_file_path,
|
||||
video_file_path,
|
||||
],
|
||||
top_k=2,
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
|
||||
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
pass
|
@ -102,6 +102,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
"AutoModel",
|
||||
),
|
||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user