From 9e56aff58a742b48fc8edea8d28d5b80330efbcc Mon Sep 17 00:00:00 2001
From: Nathan Raw
Date: Thu, 8 Dec 2022 16:22:43 -0500
Subject: [PATCH] Add video classification pipeline (#20151)

* :construction: wip video classification pipeline
* :construction: wip - add is_decord_available check
* :bug: add missing import
* :white_check_mark: add tests
* :wrench: add decord to setup extras
* :construction: add is_decord_available
* :sparkles: add video-classification pipeline
* :memo: add video classification pipe to docs
* :bug: add missing VideoClassificationPipeline import
* :pushpin: add decord install in test runner
* :white_check_mark: fix url inputs to video-classification pipeline
* :sparkles: updates from review
* :memo: add video cls pipeline to docs
* :memo: add docstring
* :fire: remove unused import
* :fire: remove some code
* :memo: docfix
---
 .circleci/create_circleci_config.py           |   2 +-
 docs/source/en/main_classes/pipelines.mdx     |   6 +
 setup.py                                      |   4 +-
 src/transformers/__init__.py                  |   4 +
 src/transformers/dependency_versions_table.py |   1 +
 src/transformers/pipelines/__init__.py        |  11 +-
 .../pipelines/video_classification.py         | 124 ++++++++++++++++++
 src/transformers/testing_utils.py             |   8 ++
 src/transformers/utils/__init__.py            |   1 +
 src/transformers/utils/import_utils.py        |  17 +++
 .../test_pipelines_video_classification.py    |  96 ++++++++++++++
 utils/update_metadata.py                      |   1 +
 12 files changed, 272 insertions(+), 3 deletions(-)
 create mode 100644 src/transformers/pipelines/video_classification.py
 create mode 100644 tests/pipelines/test_pipelines_video_classification.py

diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py
index 8c00789fbb3..8cfbd37d15f 100644
--- a/.circleci/create_circleci_config.py
+++ b/.circleci/create_circleci_config.py
@@ -188,7 +188,7 @@ pipelines_torch_job = CircleCIJob(
     install_steps=[
         "sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng",
         "pip install --upgrade pip",
-        "pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]",
+        "pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm,video]",
     ],
     pytest_options={"rA": None},
     tests_to_run="tests/pipelines/"
diff --git a/docs/source/en/main_classes/pipelines.mdx b/docs/source/en/main_classes/pipelines.mdx
index f6c63a983fc..e5ee3902028 100644
--- a/docs/source/en/main_classes/pipelines.mdx
+++ b/docs/source/en/main_classes/pipelines.mdx
@@ -341,6 +341,12 @@ Pipelines available for computer vision tasks include the following.
     - __call__
     - all
 
+### VideoClassificationPipeline
+
+[[autodoc]] VideoClassificationPipeline
+    - __call__
+    - all
+
 ### ZeroShotImageClassificationPipeline
 
 [[autodoc]] ZeroShotImageClassificationPipeline
diff --git a/setup.py b/setup.py
index 7c0bd03f11a..b089b561f1c 100644
--- a/setup.py
+++ b/setup.py
@@ -103,6 +103,7 @@ _deps = [
     "cookiecutter==1.7.3",
     "dataclasses",
     "datasets!=2.5.0",
+    "decord==0.6.0",
     "deepspeed>=0.6.5",
     "dill<0.3.5",
     "evaluate>=0.2.0",
@@ -286,7 +287,7 @@ extras["vision"] = deps_list("Pillow")
 extras["timm"] = deps_list("timm")
 extras["natten"] = deps_list("natten")
 extras["codecarbon"] = deps_list("codecarbon")
-
+extras["video"] = deps_list("decord")
 extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
 
 extras["testing"] = (
@@ -332,6 +333,7 @@ extras["all"] = (
     + extras["timm"]
     + extras["codecarbon"]
     + extras["accelerate"]
+    + extras["video"]
 )
 
 # Might need to add doc-builder and some specific deps in the future
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index dce1d73e9fa..cd997e8d057 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -489,6 +489,7 @@ _import_structure = {
         "TextGenerationPipeline",
         "TokenClassificationPipeline",
         "TranslationPipeline",
+        "VideoClassificationPipeline",
         "VisualQuestionAnsweringPipeline",
         "ZeroShotClassificationPipeline",
         "ZeroShotImageClassificationPipeline",
@@ -534,6 +535,7 @@ _import_structure = {
         "add_start_docstrings",
         "is_apex_available",
         "is_datasets_available",
+        "is_decord_available",
         "is_faiss_available",
         "is_flax_available",
         "is_keras_nlp_available",
@@ -3724,6 +3726,7 @@ if TYPE_CHECKING:
         TextGenerationPipeline,
         TokenClassificationPipeline,
         TranslationPipeline,
+        VideoClassificationPipeline,
         VisualQuestionAnsweringPipeline,
         ZeroShotClassificationPipeline,
         ZeroShotImageClassificationPipeline,
@@ -3774,6 +3777,7 @@ if TYPE_CHECKING:
         add_start_docstrings,
         is_apex_available,
         is_datasets_available,
+        is_decord_available,
         is_faiss_available,
         is_flax_available,
         is_keras_nlp_available,
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 3b8ca512a9b..75261f414ec 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -9,6 +9,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.3",
     "dataclasses": "dataclasses",
     "datasets": "datasets!=2.5.0",
+    "decord": "decord==0.6.0",
     "deepspeed": "deepspeed>=0.6.5",
     "dill": "dill<0.3.5",
     "evaluate": "evaluate>=0.2.0",
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 7d160b61a8a..685e8e16e5f 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -79,6 +79,7 @@ from .token_classification import (
     TokenClassificationArgumentHandler,
     TokenClassificationPipeline,
 )
+from .video_classification import VideoClassificationPipeline
 from .visual_question_answering import VisualQuestionAnsweringPipeline
 from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
 from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
@@ -133,6 +134,7 @@ if is_torch_available():
         AutoModelForSpeechSeq2Seq,
         AutoModelForTableQuestionAnswering,
         AutoModelForTokenClassification,
+        AutoModelForVideoClassification,
         AutoModelForVision2Seq,
         AutoModelForVisualQuestionAnswering,
         AutoModelForZeroShotObjectDetection,
@@ -361,6 +363,13 @@ SUPPORTED_TASKS = {
         "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
("Intel/dpt-large", "e93beec")}}, "type": "image", }, + "video-classification": { + "impl": VideoClassificationPipeline, + "tf": (), + "pt": (AutoModelForVideoClassification,) if is_torch_available() else (), + "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}}, + "type": "video", + }, } NO_FEATURE_EXTRACTOR_TASKS = set() @@ -373,7 +382,7 @@ MULTI_MODEL_CONFIGS = {"SpeechEncoderDecoderConfig", "VisionEncoderDecoderConfig for task, values in SUPPORTED_TASKS.items(): if values["type"] == "text": NO_FEATURE_EXTRACTOR_TASKS.add(task) - elif values["type"] in {"audio", "image"}: + elif values["type"] in {"audio", "image", "video"}: NO_TOKENIZER_TASKS.add(task) elif values["type"] != "multimodal": raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}") diff --git a/src/transformers/pipelines/video_classification.py b/src/transformers/pipelines/video_classification.py new file mode 100644 index 00000000000..8d53fb851b5 --- /dev/null +++ b/src/transformers/pipelines/video_classification.py @@ -0,0 +1,124 @@ +from io import BytesIO +from typing import List, Union + +import requests + +from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends +from .base import PIPELINE_INIT_ARGS, Pipeline + + +if is_decord_available(): + import numpy as np + + from decord import VideoReader + + +if is_torch_available(): + from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING + +logger = logging.get_logger(__name__) + + +@add_end_docstrings(PIPELINE_INIT_ARGS) +class VideoClassificationPipeline(Pipeline): + """ + Video classification pipeline using any `AutoModelForVideoClassification`. This pipeline predicts the class of a + video. + + This video classification pipeline can currently be loaded from [`pipeline`] using the following task identifier: + `"video-classification"`. + + See the list of available models on + [huggingface.co/models](https://huggingface.co/models?filter=video-classification). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + requires_backends(self, "decord") + self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING) + + def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None): + preprocess_params = {} + if frame_sampling_rate is not None: + preprocess_params["frame_sampling_rate"] = frame_sampling_rate + if num_frames is not None: + preprocess_params["num_frames"] = num_frames + + postprocess_params = {} + if top_k is not None: + postprocess_params["top_k"] = top_k + return preprocess_params, {}, postprocess_params + + def __call__(self, videos: Union[str, List[str]], **kwargs): + """ + Assign labels to the video(s) passed as inputs. + + Args: + videos (`str`, `List[str]`): + The pipeline handles three types of videos: + + - A string containing a http link pointing to a video + - A string containing a local path to a video + + The pipeline accepts either a single video or a batch of videos, which must then be passed as a string. + Videos in a batch must all be in the same format: all as http links or all as local paths. + top_k (`int`, *optional*, defaults to 5): + The number of top labels that will be returned by the pipeline. If the provided number is higher than + the number of labels available in the model configuration, it will default to the number of labels. 
+ num_frames (`int`, *optional*, defaults to `self.model.config.num_frames`): + The number of frames sampled from the video to run the classification on. If not provided, will default + to the number of frames specified in the model configuration. + frame_sampling_rate (`int`, *optional*, defaults to 1): + The sampling rate used to select frames from the video. If not provided, will default to 1, i.e. every + frame will be used. + + Return: + A dictionary or a list of dictionaries containing result. If the input is a single video, will return a + dictionary, if the input is a list of several videos, will return a list of dictionaries corresponding to + the videos. + + The dictionaries contain the following keys: + + - **label** (`str`) -- The label identified by the model. + - **score** (`int`) -- The score attributed by the model for that label. + """ + return super().__call__(videos, **kwargs) + + def preprocess(self, video, num_frames=None, frame_sampling_rate=1): + + if num_frames is None: + num_frames = self.model.config.num_frames + + if video.startswith("http://") or video.startswith("https://"): + video = BytesIO(requests.get(video).content) + + videoreader = VideoReader(video) + videoreader.seek(0) + + start_idx = 0 + end_idx = num_frames * frame_sampling_rate - 1 + indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64) + + video = videoreader.get_batch(indices).asnumpy() + video = list(video) + + model_inputs = self.feature_extractor(video, return_tensors=self.framework) + return model_inputs + + def _forward(self, model_inputs): + model_outputs = self.model(**model_inputs) + return model_outputs + + def postprocess(self, model_outputs, top_k=5): + if top_k > self.model.config.num_labels: + top_k = self.model.config.num_labels + + if self.framework == "pt": + probs = model_outputs.logits.softmax(-1)[0] + scores, ids = probs.topk(top_k) + else: + raise ValueError(f"Unsupported framework: {self.framework}") + + scores = scores.tolist() + ids = ids.tolist() + return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)] diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 5ef6bfd36aa..31760557aa9 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -51,6 +51,7 @@ from .utils import ( is_apex_available, is_bitsandbytes_available, is_bs4_available, + is_decord_available, is_detectron2_available, is_faiss_available, is_flax_available, @@ -446,6 +447,13 @@ def require_spacy(test_case): return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case) +def require_decord(test_case): + """ + Decorator marking a test that requires decord. These tests are skipped when decord isn't installed. + """ + return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case) + + def require_torch_multi_gpu(test_case): """ Decorator marking a test that requires a multi-GPU setup (in PyTorch). 
These tests are skipped on a machine without diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index c0cf73e6fa4..525149417f2 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -104,6 +104,7 @@ from .import_utils import ( is_bs4_available, is_coloredlogs_available, is_datasets_available, + is_decord_available, is_detectron2_available, is_faiss_available, is_flax_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index bce7f769159..1e53ab46ba5 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -268,6 +268,13 @@ try: except importlib_metadata.PackageNotFoundError: _is_ccl_available = False +_decord_availale = importlib.util.find_spec("decord") is not None +try: + _decord_version = importlib_metadata.version("decord") + logger.debug(f"Successfully imported decord version {_decord_version}") +except importlib_metadata.PackageNotFoundError: + _decord_availale = False + # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. TORCH_FX_REQUIRED_VERSION = version.parse("1.10") TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8") @@ -706,6 +713,10 @@ def is_ccl_available(): return _is_ccl_available +def is_decord_available(): + return _decord_availale + + def is_sudachi_available(): return importlib.util.find_spec("sudachipy") is not None @@ -953,6 +964,11 @@ CCL_IMPORT_ERROR = """ Please note that you may need to restart your runtime after installation. """ +DECORD_IMPORT_ERROR = """ +{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install +decord`. Please note that you may need to restart your runtime after installation. +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -982,6 +998,7 @@ BACKENDS_MAPPING = OrderedDict( ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)), ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)), + ("decord", (is_decord_available, DECORD_IMPORT_ERROR)), ] ) diff --git a/tests/pipelines/test_pipelines_video_classification.py b/tests/pipelines/test_pipelines_video_classification.py new file mode 100644 index 00000000000..25ddcfaf2d3 --- /dev/null +++ b/tests/pipelines/test_pipelines_video_classification.py @@ -0,0 +1,96 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+from huggingface_hub import hf_hub_download
+from transformers import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, VideoMAEFeatureExtractor
+from transformers.pipelines import VideoClassificationPipeline, pipeline
+from transformers.testing_utils import (
+    nested_simplify,
+    require_decord,
+    require_tf,
+    require_torch,
+    require_torch_or_tf,
+    require_vision,
+)
+
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
+
+
+@require_torch_or_tf
+@require_vision
+@require_decord
+class VideoClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+    model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
+
+    def get_test_pipeline(self, model, tokenizer, feature_extractor):
+        example_video_filepath = hf_hub_download(
+            repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset"
+        )
+        video_classifier = VideoClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
+        examples = [
+            example_video_filepath,
+            "https://huggingface.co/datasets/nateraw/video-demo/resolve/main/archery.mp4",
+        ]
+        return video_classifier, examples
+
+    def run_pipeline_test(self, video_classifier, examples):
+
+        for example in examples:
+            outputs = video_classifier(example)
+
+            self.assertEqual(
+                outputs,
+                [
+                    {"score": ANY(float), "label": ANY(str)},
+                    {"score": ANY(float), "label": ANY(str)},
+                ],
+            )
+
+    @require_torch
+    def test_small_model_pt(self):
+        small_model = "hf-internal-testing/tiny-random-VideoMAEForVideoClassification"
+        small_feature_extractor = VideoMAEFeatureExtractor(
+            size=dict(shortest_edge=10), crop_size=dict(height=10, width=10)
+        )
+        video_classifier = pipeline(
+            "video-classification", model=small_model, feature_extractor=small_feature_extractor, frame_sampling_rate=4
+        )
+
+        video_file_path = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")
+        outputs = video_classifier(video_file_path, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
+        )
+
+        outputs = video_classifier(
+            [
+                video_file_path,
+                video_file_path,
+            ],
+            top_k=2,
+        )
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                [{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
+                [{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
+            ],
+        )
+
+    @require_tf
+    def test_small_model_tf(self):
+        pass
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 5e7169c2558..e624759ebe2 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -102,6 +102,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
         "AutoModel",
     ),
     ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
+    ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
 ]
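
Usage note (not part of the patch): the sketch below shows one way the new `video-classification` task could be exercised after installing `transformers` with the new `video` extra (`pip install transformers[video]`) plus PyTorch. The demo clip is the same public `nateraw/video-demo` file the tests download; calling `pipeline("video-classification")` with no model resolves to the default checkpoint registered above (`MCG-NJU/videomae-base-finetuned-kinetics`), so a network connection and a sizeable download are assumed.

```python
# Minimal sketch of the new "video-classification" pipeline.
# Assumes: pip install transformers[video] torch, plus network access for the demo clip and checkpoint.
from huggingface_hub import hf_hub_download

from transformers import pipeline

# Same public demo clip the tests use; any local video path or http(s) URL also works.
video_path = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")

# With no model argument this resolves to the task default registered in SUPPORTED_TASKS
# ("MCG-NJU/videomae-base-finetuned-kinetics").
video_classifier = pipeline("video-classification")

# A single video returns a list of {"score": float, "label": str} dicts, highest score first.
predictions = video_classifier(video_path, top_k=3)
print(predictions)

# A batch of videos returns one such list per input; num_frames and frame_sampling_rate
# control how frames are sampled before the feature extractor sees them.
batch_predictions = video_classifier([video_path, video_path], top_k=1, num_frames=16, frame_sampling_rate=4)
print(batch_predictions)
```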