mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Replace 'decord' with 'av' in VideoClassificationPipeline (#29747)
* replace the 'decord' with 'av' in VideoClassificationPipeline * fix the check of backend in VideoClassificationPipeline * adjust the order of imports * format 'video_classification.py' * format 'video_classification.py' with ruff --------- Co-authored-by: wanqiancheng <13541261013@163.com>
This commit is contained in:
parent
b5a6d6eeab
commit
b32bf85b58
@ -1083,6 +1083,7 @@ _import_structure = {
|
||||
"add_end_docstrings",
|
||||
"add_start_docstrings",
|
||||
"is_apex_available",
|
||||
"is_av_available",
|
||||
"is_bitsandbytes_available",
|
||||
"is_datasets_available",
|
||||
"is_decord_available",
|
||||
@ -5951,6 +5952,7 @@ if TYPE_CHECKING:
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
is_apex_available,
|
||||
is_av_available,
|
||||
is_bitsandbytes_available,
|
||||
is_datasets_available,
|
||||
is_decord_available,
|
||||
|
@ -3,13 +3,19 @@ from typing import List, Union
|
||||
|
||||
import requests
|
||||
|
||||
from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_av_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import Pipeline, build_pipeline_init_args
|
||||
|
||||
|
||||
if is_decord_available():
|
||||
if is_av_available():
|
||||
import av
|
||||
import numpy as np
|
||||
from decord import VideoReader
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -33,7 +39,7 @@ class VideoClassificationPipeline(Pipeline):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
requires_backends(self, "decord")
|
||||
requires_backends(self, "av")
|
||||
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
|
||||
|
||||
def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
|
||||
@ -90,14 +96,13 @@ class VideoClassificationPipeline(Pipeline):
|
||||
if video.startswith("http://") or video.startswith("https://"):
|
||||
video = BytesIO(requests.get(video).content)
|
||||
|
||||
videoreader = VideoReader(video)
|
||||
videoreader.seek(0)
|
||||
container = av.open(video)
|
||||
|
||||
start_idx = 0
|
||||
end_idx = num_frames * frame_sampling_rate - 1
|
||||
indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)
|
||||
|
||||
video = videoreader.get_batch(indices).asnumpy()
|
||||
video = read_video_pyav(container, indices)
|
||||
video = list(video)
|
||||
|
||||
model_inputs = self.image_processor(video, return_tensors=self.framework)
|
||||
@ -120,3 +125,16 @@ class VideoClassificationPipeline(Pipeline):
|
||||
scores = scores.tolist()
|
||||
ids = ids.tolist()
|
||||
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
||||
|
||||
|
||||
def read_video_pyav(container, indices):
|
||||
frames = []
|
||||
container.seek(0)
|
||||
start_index = indices[0]
|
||||
end_index = indices[-1]
|
||||
for i, frame in enumerate(container.decode(video=0)):
|
||||
if i > end_index:
|
||||
break
|
||||
if i >= start_index and i in indices:
|
||||
frames.append(frame)
|
||||
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
|
@ -57,6 +57,7 @@ from .utils import (
|
||||
is_aqlm_available,
|
||||
is_auto_awq_available,
|
||||
is_auto_gptq_available,
|
||||
is_av_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bs4_available,
|
||||
is_cv2_available,
|
||||
@ -1010,6 +1011,13 @@ def require_aqlm(test_case):
|
||||
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
|
||||
|
||||
|
||||
def require_av(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires av
|
||||
"""
|
||||
return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
|
||||
|
@ -109,6 +109,7 @@ from .import_utils import (
|
||||
is_aqlm_available,
|
||||
is_auto_awq_available,
|
||||
is_auto_gptq_available,
|
||||
is_av_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bs4_available,
|
||||
is_coloredlogs_available,
|
||||
|
@ -94,6 +94,7 @@ FSDP_MIN_VERSION = "1.12.0"
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||
_apex_available = _is_package_available("apex")
|
||||
_aqlm_available = _is_package_available("aqlm")
|
||||
_av_available = importlib.util.find_spec("av") is not None
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_galore_torch_available = _is_package_available("galore_torch")
|
||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||
@ -656,6 +657,10 @@ def is_aqlm_available():
|
||||
return _aqlm_available
|
||||
|
||||
|
||||
def is_av_available():
|
||||
return _av_available
|
||||
|
||||
|
||||
def is_ninja_available():
|
||||
r"""
|
||||
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
|
||||
@ -1012,6 +1017,16 @@ def is_mlx_available():
|
||||
return _mlx_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
AV_IMPORT_ERROR = """
|
||||
{0} requires the PyAv library but it was not found in your environment. You can install it with:
|
||||
```
|
||||
pip install av
|
||||
```
|
||||
Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
CV2_IMPORT_ERROR = """
|
||||
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
|
||||
@ -1336,6 +1351,7 @@ jinja2`. Please note that you may need to restart your runtime after installatio
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("av", (is_av_available, AV_IMPORT_ERROR)),
|
||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
|
||||
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
|
||||
|
@ -21,7 +21,7 @@ from transformers.pipelines import VideoClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_decord,
|
||||
require_av,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_or_tf,
|
||||
@ -34,7 +34,7 @@ from .test_pipelines_common import ANY
|
||||
@is_pipeline_test
|
||||
@require_torch_or_tf
|
||||
@require_vision
|
||||
@require_decord
|
||||
@require_av
|
||||
class VideoClassificationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user