Replace 'decord' with 'av' in VideoClassificationPipeline (#29747)

* replace the 'decord' with 'av' in VideoClassificationPipeline

* fix the check of backend in VideoClassificationPipeline

* adjust the order of imports

* format 'video_classification.py'

* format 'video_classification.py' with ruff

---------

Co-authored-by: wanqiancheng <13541261013@163.com>
This commit is contained in:
yunxiangtang 2024-03-26 18:12:24 +08:00 committed by GitHub
parent b5a6d6eeab
commit b32bf85b58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 54 additions and 9 deletions

View File

@ -1083,6 +1083,7 @@ _import_structure = {
"add_end_docstrings",
"add_start_docstrings",
"is_apex_available",
"is_av_available",
"is_bitsandbytes_available",
"is_datasets_available",
"is_decord_available",
@ -5951,6 +5952,7 @@ if TYPE_CHECKING:
add_end_docstrings,
add_start_docstrings,
is_apex_available,
is_av_available,
is_bitsandbytes_available,
is_datasets_available,
is_decord_available,

View File

@ -3,13 +3,19 @@ from typing import List, Union
import requests
from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
from ..utils import (
add_end_docstrings,
is_av_available,
is_torch_available,
logging,
requires_backends,
)
from .base import Pipeline, build_pipeline_init_args
if is_decord_available():
if is_av_available():
import av
import numpy as np
from decord import VideoReader
if is_torch_available():
@ -33,7 +39,7 @@ class VideoClassificationPipeline(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "decord")
requires_backends(self, "av")
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
@ -90,14 +96,13 @@ class VideoClassificationPipeline(Pipeline):
if video.startswith("http://") or video.startswith("https://"):
video = BytesIO(requests.get(video).content)
videoreader = VideoReader(video)
videoreader.seek(0)
container = av.open(video)
start_idx = 0
end_idx = num_frames * frame_sampling_rate - 1
indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)
video = videoreader.get_batch(indices).asnumpy()
video = read_video_pyav(container, indices)
video = list(video)
model_inputs = self.image_processor(video, return_tensors=self.framework)
@ -120,3 +125,16 @@ class VideoClassificationPipeline(Pipeline):
scores = scores.tolist()
ids = ids.tolist()
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
def read_video_pyav(container, indices):
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])

View File

@ -57,6 +57,7 @@ from .utils import (
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_av_available,
is_bitsandbytes_available,
is_bs4_available,
is_cv2_available,
@ -1010,6 +1011,13 @@ def require_aqlm(test_case):
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
def require_av(test_case):
"""
Decorator marking a test that requires av
"""
return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.

View File

@ -109,6 +109,7 @@ from .import_utils import (
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_av_available,
is_bitsandbytes_available,
is_bs4_available,
is_coloredlogs_available,

View File

@ -94,6 +94,7 @@ FSDP_MIN_VERSION = "1.12.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_galore_torch_available = _is_package_available("galore_torch")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
@ -656,6 +657,10 @@ def is_aqlm_available():
return _aqlm_available
def is_av_available():
return _av_available
def is_ninja_available():
r"""
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@ -1012,6 +1017,16 @@ def is_mlx_available():
return _mlx_available
# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
```
pip install av
```
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
@ -1336,6 +1351,7 @@ jinja2`. Please note that you may need to restart your runtime after installatio
BACKENDS_MAPPING = OrderedDict(
[
("av", (is_av_available, AV_IMPORT_ERROR)),
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),

View File

@ -21,7 +21,7 @@ from transformers.pipelines import VideoClassificationPipeline, pipeline
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_decord,
require_av,
require_tf,
require_torch,
require_torch_or_tf,
@ -34,7 +34,7 @@ from .test_pipelines_common import ANY
@is_pipeline_test
@require_torch_or_tf
@require_vision
@require_decord
@require_av
class VideoClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING