Add video classification pipeline (#20151)

* 🚧 wip video classification pipeline

* 🚧 wip - add is_decord_available check

* 🐛 add missing import

* add tests

* 🔧 add decord to setup extras

* 🚧 add is_decord_available

* add video-classification pipeline

* 📝 add video classification pipe to docs

* 🐛 add missing VideoClassificationPipeline import

* 📌 add decord install in test runner

* fix url inputs to video-classification pipeline

* updates from review

* 📝 add video cls pipeline to docs

* 📝 add docstring

* 🔥 remove unused import

* 🔥 remove some code

* 📝 docfix
Authored by Nathan Raw on 2022-12-08 16:22:43 -05:00, committed by GitHub
parent c56ebbbea6
commit 9e56aff58a
12 changed files with 272 additions and 3 deletions

========== changed file ==========

@@ -188,7 +188,7 @@ pipelines_torch_job = CircleCIJob(
     install_steps=[
         "sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng",
         "pip install --upgrade pip",
-        "pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]",
+        "pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm,video]",
     ],
     pytest_options={"rA": None},
     tests_to_run="tests/pipelines/"

========== changed file ==========

@@ -341,6 +341,12 @@ Pipelines available for computer vision tasks include the following.
     - __call__
     - all
 
+### VideoClassificationPipeline
+
+[[autodoc]] VideoClassificationPipeline
+    - __call__
+    - all
+
 ### ZeroShotImageClassificationPipeline
 
 [[autodoc]] ZeroShotImageClassificationPipeline

========== changed file ==========

@@ -103,6 +103,7 @@ _deps = [
     "cookiecutter==1.7.3",
     "dataclasses",
     "datasets!=2.5.0",
+    "decord==0.6.0",
     "deepspeed>=0.6.5",
     "dill<0.3.5",
     "evaluate>=0.2.0",

@@ -286,7 +287,7 @@ extras["vision"] = deps_list("Pillow")
 extras["timm"] = deps_list("timm")
 extras["natten"] = deps_list("natten")
 extras["codecarbon"] = deps_list("codecarbon")
+extras["video"] = deps_list("decord")
 
 extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
 extras["testing"] = (

@@ -332,6 +333,7 @@ extras["all"] = (
     + extras["timm"]
     + extras["codecarbon"]
     + extras["accelerate"]
+    + extras["video"]
 )
 
 # Might need to add doc-builder and some specific deps in the future
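Note: with this extra in place, decord can be pulled in alongside the rest of the library, e.g. `pip install ".[video]"` for a source checkout (as the CircleCI job above now does), or presumably `pip install "transformers[video]"` once a release ships with this change.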

========== changed file ==========

@@ -489,6 +489,7 @@ _import_structure = {
         "TextGenerationPipeline",
         "TokenClassificationPipeline",
         "TranslationPipeline",
+        "VideoClassificationPipeline",
         "VisualQuestionAnsweringPipeline",
         "ZeroShotClassificationPipeline",
         "ZeroShotImageClassificationPipeline",

@@ -534,6 +535,7 @@ _import_structure = {
         "add_start_docstrings",
         "is_apex_available",
         "is_datasets_available",
+        "is_decord_available",
         "is_faiss_available",
         "is_flax_available",
         "is_keras_nlp_available",

@@ -3724,6 +3726,7 @@ if TYPE_CHECKING:
         TextGenerationPipeline,
         TokenClassificationPipeline,
         TranslationPipeline,
+        VideoClassificationPipeline,
         VisualQuestionAnsweringPipeline,
         ZeroShotClassificationPipeline,
         ZeroShotImageClassificationPipeline,

@@ -3774,6 +3777,7 @@ if TYPE_CHECKING:
         add_start_docstrings,
         is_apex_available,
         is_datasets_available,
+        is_decord_available,
         is_faiss_available,
         is_flax_available,
         is_keras_nlp_available,

========== changed file ==========

@@ -9,6 +9,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.3",
     "dataclasses": "dataclasses",
     "datasets": "datasets!=2.5.0",
+    "decord": "decord==0.6.0",
     "deepspeed": "deepspeed>=0.6.5",
     "dill": "dill<0.3.5",
     "evaluate": "evaluate>=0.2.0",

========== changed file ==========

@@ -79,6 +79,7 @@ from .token_classification import (
     TokenClassificationArgumentHandler,
     TokenClassificationPipeline,
 )
+from .video_classification import VideoClassificationPipeline
 from .visual_question_answering import VisualQuestionAnsweringPipeline
 from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
 from .zero_shot_image_classification import ZeroShotImageClassificationPipeline

@@ -133,6 +134,7 @@ if is_torch_available():
         AutoModelForSpeechSeq2Seq,
         AutoModelForTableQuestionAnswering,
         AutoModelForTokenClassification,
+        AutoModelForVideoClassification,
         AutoModelForVision2Seq,
         AutoModelForVisualQuestionAnswering,
         AutoModelForZeroShotObjectDetection,

@@ -361,6 +363,13 @@ SUPPORTED_TASKS = {
         "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
         "type": "image",
     },
+    "video-classification": {
+        "impl": VideoClassificationPipeline,
+        "tf": (),
+        "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
+        "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
+        "type": "video",
+    },
 }
 
 NO_FEATURE_EXTRACTOR_TASKS = set()

@@ -373,7 +382,7 @@ MULTI_MODEL_CONFIGS = {"SpeechEncoderDecoderConfig", "VisionEncoderDecoderConfig
 for task, values in SUPPORTED_TASKS.items():
     if values["type"] == "text":
         NO_FEATURE_EXTRACTOR_TASKS.add(task)
-    elif values["type"] in {"audio", "image"}:
+    elif values["type"] in {"audio", "image", "video"}:
         NO_TOKENIZER_TASKS.add(task)
     elif values["type"] != "multimodal":
         raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
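For reference, a minimal usage sketch of the task registered above (not part of the diff; it assumes torch and decord are installed, uses the default checkpoint pinned in SUPPORTED_TASKS, and borrows the example clip from the test file further down — labels and scores will depend on the model):

```python
from transformers import pipeline

# "video-classification" now resolves to VideoClassificationPipeline, with the
# MCG-NJU/videomae-base-finetuned-kinetics checkpoint as its default model.
video_classifier = pipeline("video-classification")

# A single video given as an HTTP link (a local path works too); top_k caps the labels returned.
predictions = video_classifier(
    "https://huggingface.co/datasets/nateraw/video-demo/resolve/main/archery.mp4",
    top_k=3,
)
print(predictions)  # e.g. [{"score": 0.99, "label": "archery"}, ...]
```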

========== changed file ==========

@@ -0,0 +1,124 @@
from io import BytesIO
from typing import List, Union

import requests

from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline


if is_decord_available():
    import numpy as np

    from decord import VideoReader


if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING

logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class VideoClassificationPipeline(Pipeline):
    """
    Video classification pipeline using any `AutoModelForVideoClassification`. This pipeline predicts the class of a
    video.

    This video classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"video-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=video-classification).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "decord")
        self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING)

    def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
        preprocess_params = {}
        if frame_sampling_rate is not None:
            preprocess_params["frame_sampling_rate"] = frame_sampling_rate
        if num_frames is not None:
            preprocess_params["num_frames"] = num_frames

        postprocess_params = {}
        if top_k is not None:
            postprocess_params["top_k"] = top_k
        return preprocess_params, {}, postprocess_params

    def __call__(self, videos: Union[str, List[str]], **kwargs):
        """
        Assign labels to the video(s) passed as inputs.

        Args:
            videos (`str`, `List[str]`):
                The pipeline handles two types of videos:

                - A string containing an HTTP link pointing to a video
                - A string containing a local path to a video

                The pipeline accepts either a single video or a batch of videos, which must then be passed as a list
                of strings. Videos in a batch must all be in the same format: all as HTTP links or all as local paths.
            top_k (`int`, *optional*, defaults to 5):
                The number of top labels that will be returned by the pipeline. If the provided number is higher than
                the number of labels available in the model configuration, it will default to the number of labels.
            num_frames (`int`, *optional*, defaults to `self.model.config.num_frames`):
                The number of frames sampled from the video to run the classification on. If not provided, will
                default to the number of frames specified in the model configuration.
            frame_sampling_rate (`int`, *optional*, defaults to 1):
                The sampling rate used to select frames from the video. If not provided, will default to 1, i.e. every
                frame will be used.

        Return:
            A dictionary or a list of dictionaries containing the result. If the input is a single video, will return
            a dictionary, if the input is a list of several videos, will return a list of dictionaries corresponding
            to the videos.

            The dictionaries contain the following keys:

            - **label** (`str`) -- The label identified by the model.
            - **score** (`float`) -- The score attributed by the model to that label.
        """
        return super().__call__(videos, **kwargs)

    def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
        if num_frames is None:
            num_frames = self.model.config.num_frames

        if video.startswith("http://") or video.startswith("https://"):
            video = BytesIO(requests.get(video).content)

        videoreader = VideoReader(video)
        videoreader.seek(0)

        start_idx = 0
        end_idx = num_frames * frame_sampling_rate - 1
        indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)

        video = videoreader.get_batch(indices).asnumpy()
        video = list(video)

        model_inputs = self.feature_extractor(video, return_tensors=self.framework)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, top_k=5):
        if top_k > self.model.config.num_labels:
            top_k = self.model.config.num_labels

        if self.framework == "pt":
            probs = model_outputs.logits.softmax(-1)[0]
            scores, ids = probs.topk(top_k)
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        scores = scores.tolist()
        ids = ids.tolist()
        return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
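A small standalone sketch (illustrative values only) of the frame selection that `preprocess` performs before handing frames to the feature extractor:

```python
import numpy as np

# With num_frames=8 and frame_sampling_rate=4, indices are spread evenly across
# the first num_frames * frame_sampling_rate = 32 frames of the clip.
num_frames, frame_sampling_rate = 8, 4
end_idx = num_frames * frame_sampling_rate - 1
indices = np.linspace(0, end_idx, num=num_frames, dtype=np.int64)
print(indices)  # [ 0  4  8 13 17 22 26 31]
```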

========== changed file ==========

@@ -51,6 +51,7 @@ from .utils import (
     is_apex_available,
     is_bitsandbytes_available,
     is_bs4_available,
+    is_decord_available,
     is_detectron2_available,
     is_faiss_available,
     is_flax_available,

@@ -446,6 +447,13 @@ def require_spacy(test_case):
     return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
 
 
+def require_decord(test_case):
+    """
+    Decorator marking a test that requires decord. These tests are skipped when decord isn't installed.
+    """
+    return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case)
+
+
 def require_torch_multi_gpu(test_case):
     """
     Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
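Illustrative only (the test class below is hypothetical; the real usage is in the new pipeline test file later in this diff), a sketch of how the decorator is applied:

```python
import unittest

from transformers.testing_utils import require_decord


@require_decord
class DecordDependentTest(unittest.TestCase):
    def test_video_reader_importable(self):
        # Only reached when decord is installed; otherwise the whole test is skipped.
        from decord import VideoReader

        self.assertTrue(callable(VideoReader))
```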

========== changed file ==========

@@ -104,6 +104,7 @@ from .import_utils import (
     is_bs4_available,
     is_coloredlogs_available,
     is_datasets_available,
+    is_decord_available,
     is_detectron2_available,
     is_faiss_available,
     is_flax_available,

========== changed file ==========

@@ -268,6 +268,13 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _is_ccl_available = False
 
+_decord_available = importlib.util.find_spec("decord") is not None
+try:
+    _decord_version = importlib_metadata.version("decord")
+    logger.debug(f"Successfully imported decord version {_decord_version}")
+except importlib_metadata.PackageNotFoundError:
+    _decord_available = False
+
 # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
 TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
 TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")

@@ -706,6 +713,10 @@ def is_ccl_available():
     return _is_ccl_available
 
 
+def is_decord_available():
+    return _decord_available
+
+
 def is_sudachi_available():
     return importlib.util.find_spec("sudachipy") is not None

@@ -953,6 +964,11 @@ CCL_IMPORT_ERROR = """
 Please note that you may need to restart your runtime after installation.
 """
 
+DECORD_IMPORT_ERROR = """
+{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
+decord`. Please note that you may need to restart your runtime after installation.
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),

@@ -982,6 +998,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
         ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
         ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
+        ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
     ]
 )
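For context, a rough sketch (not part of this diff) of what the new `BACKENDS_MAPPING` entry buys: `requires_backends(self, "decord")` in the pipeline's `__init__` looks up the `(is_decord_available, DECORD_IMPORT_ERROR)` pair and raises the formatted message when the library is missing, roughly equivalent to:

```python
from transformers.utils import is_decord_available

# Approximation of the check requires_backends performs for the "decord" backend.
if not is_decord_available():
    raise ImportError(
        "VideoClassificationPipeline requires the decord library but it was not found in your "
        "environment. You can install it with pip: `pip install decord`. Please note that you "
        "may need to restart your runtime after installation."
    )
```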

========== changed file ==========

@@ -0,0 +1,96 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from huggingface_hub import hf_hub_download

from transformers import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, VideoMAEFeatureExtractor
from transformers.pipelines import VideoClassificationPipeline, pipeline
from transformers.testing_utils import (
    nested_simplify,
    require_decord,
    require_tf,
    require_torch,
    require_torch_or_tf,
    require_vision,
)

from .test_pipelines_common import ANY, PipelineTestCaseMeta


@require_torch_or_tf
@require_vision
@require_decord
class VideoClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING

    def get_test_pipeline(self, model, tokenizer, feature_extractor):
        example_video_filepath = hf_hub_download(
            repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset"
        )
        video_classifier = VideoClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
        examples = [
            example_video_filepath,
            "https://huggingface.co/datasets/nateraw/video-demo/resolve/main/archery.mp4",
        ]
        return video_classifier, examples

    def run_pipeline_test(self, video_classifier, examples):
        for example in examples:
            outputs = video_classifier(example)

            self.assertEqual(
                outputs,
                [
                    {"score": ANY(float), "label": ANY(str)},
                    {"score": ANY(float), "label": ANY(str)},
                ],
            )

    @require_torch
    def test_small_model_pt(self):
        small_model = "hf-internal-testing/tiny-random-VideoMAEForVideoClassification"
        small_feature_extractor = VideoMAEFeatureExtractor(
            size=dict(shortest_edge=10), crop_size=dict(height=10, width=10)
        )
        video_classifier = pipeline(
            "video-classification", model=small_model, feature_extractor=small_feature_extractor, frame_sampling_rate=4
        )

        video_file_path = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")
        outputs = video_classifier(video_file_path, top_k=2)
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
        )

        outputs = video_classifier(
            [
                video_file_path,
                video_file_path,
            ],
            top_k=2,
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
                [{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
            ],
        )

    @require_tf
    def test_small_model_tf(self):
        pass

========== changed file ==========

@@ -102,6 +102,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
         "AutoModel",
     ),
     ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
+    ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
 ]