mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Replace 'decord' with 'av' in VideoClassificationPipeline (#29747)
* replace the 'decord' with 'av' in VideoClassificationPipeline * fix the check of backend in VideoClassificationPipeline * adjust the order of imports * format 'video_classification.py' * format 'video_classification.py' with ruff --------- Co-authored-by: wanqiancheng <13541261013@163.com>
This commit is contained in:
parent
b5a6d6eeab
commit
b32bf85b58
@ -1083,6 +1083,7 @@ _import_structure = {
|
|||||||
"add_end_docstrings",
|
"add_end_docstrings",
|
||||||
"add_start_docstrings",
|
"add_start_docstrings",
|
||||||
"is_apex_available",
|
"is_apex_available",
|
||||||
|
"is_av_available",
|
||||||
"is_bitsandbytes_available",
|
"is_bitsandbytes_available",
|
||||||
"is_datasets_available",
|
"is_datasets_available",
|
||||||
"is_decord_available",
|
"is_decord_available",
|
||||||
@ -5951,6 +5952,7 @@ if TYPE_CHECKING:
|
|||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
|
is_av_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
is_decord_available,
|
is_decord_available,
|
||||||
|
@ -3,13 +3,19 @@ from typing import List, Union
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
|
from ..utils import (
|
||||||
|
add_end_docstrings,
|
||||||
|
is_av_available,
|
||||||
|
is_torch_available,
|
||||||
|
logging,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
from .base import Pipeline, build_pipeline_init_args
|
from .base import Pipeline, build_pipeline_init_args
|
||||||
|
|
||||||
|
|
||||||
if is_decord_available():
|
if is_av_available():
|
||||||
|
import av
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from decord import VideoReader
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@ -33,7 +39,7 @@ class VideoClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
requires_backends(self, "decord")
|
requires_backends(self, "av")
|
||||||
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
|
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
|
||||||
|
|
||||||
def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
|
def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
|
||||||
@ -90,14 +96,13 @@ class VideoClassificationPipeline(Pipeline):
|
|||||||
if video.startswith("http://") or video.startswith("https://"):
|
if video.startswith("http://") or video.startswith("https://"):
|
||||||
video = BytesIO(requests.get(video).content)
|
video = BytesIO(requests.get(video).content)
|
||||||
|
|
||||||
videoreader = VideoReader(video)
|
container = av.open(video)
|
||||||
videoreader.seek(0)
|
|
||||||
|
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
end_idx = num_frames * frame_sampling_rate - 1
|
end_idx = num_frames * frame_sampling_rate - 1
|
||||||
indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)
|
indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)
|
||||||
|
|
||||||
video = videoreader.get_batch(indices).asnumpy()
|
video = read_video_pyav(container, indices)
|
||||||
video = list(video)
|
video = list(video)
|
||||||
|
|
||||||
model_inputs = self.image_processor(video, return_tensors=self.framework)
|
model_inputs = self.image_processor(video, return_tensors=self.framework)
|
||||||
@ -120,3 +125,16 @@ class VideoClassificationPipeline(Pipeline):
|
|||||||
scores = scores.tolist()
|
scores = scores.tolist()
|
||||||
ids = ids.tolist()
|
ids = ids.tolist()
|
||||||
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
||||||
|
|
||||||
|
|
||||||
|
def read_video_pyav(container, indices):
    """
    Decode the selected frames of a video with PyAV and return them as one array.

    Args:
        container: An opened PyAV container (e.g. the object returned by `av.open`). It is rewound with
            `container.seek(0)` before decoding and frames are read from its first video stream.
        indices: Sequence (list or array) of frame positions to keep. It is expected to be in ascending
            order, because its first and last entries are used as the lower and upper decoding bounds.

    Returns:
        `np.ndarray`: The selected frames, each converted with `to_ndarray(format="rgb24")`, stacked along
        a new leading axis. If the video ends before the last requested index, only the frames that could
        be decoded are returned.

    Raises:
        IndexError: If `indices` is empty.
        ValueError: If none of the requested frames could be decoded.
    """
    container.seek(0)
    start_index = int(indices[0])
    end_index = int(indices[-1])
    # Build the lookup once: membership tests against a set are O(1), whereas `i in <array>`
    # rescans the whole index array for every decoded frame.
    wanted = {int(index) for index in indices}

    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            # Everything requested has been seen; stop decoding the rest of the stream.
            break
        if i >= start_index and i in wanted:
            # Convert right away so the decoder's frame objects are not kept alive.
            frames.append(frame.to_ndarray(format="rgb24"))

    if not frames:
        # `np.stack` on an empty list fails with an unhelpful message; report the real cause.
        raise ValueError(
            "No frames could be read for the requested indices; the video may be empty or shorter than expected."
        )
    return np.stack(frames)
|
||||||
|
@ -57,6 +57,7 @@ from .utils import (
|
|||||||
is_aqlm_available,
|
is_aqlm_available,
|
||||||
is_auto_awq_available,
|
is_auto_awq_available,
|
||||||
is_auto_gptq_available,
|
is_auto_gptq_available,
|
||||||
|
is_av_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
is_bs4_available,
|
is_bs4_available,
|
||||||
is_cv2_available,
|
is_cv2_available,
|
||||||
@ -1010,6 +1011,13 @@ def require_aqlm(test_case):
|
|||||||
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
|
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_av(test_case):
    """
    Decorator marking a test that requires av. The test is skipped when `is_av_available()` returns a falsy
    value.
    """
    skip_without_av = unittest.skipUnless(is_av_available(), "test requires av")
    return skip_without_av(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_bitsandbytes(test_case):
|
def require_bitsandbytes(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
|
Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
|
||||||
|
@ -109,6 +109,7 @@ from .import_utils import (
|
|||||||
is_aqlm_available,
|
is_aqlm_available,
|
||||||
is_auto_awq_available,
|
is_auto_awq_available,
|
||||||
is_auto_gptq_available,
|
is_auto_gptq_available,
|
||||||
|
is_av_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
is_bs4_available,
|
is_bs4_available,
|
||||||
is_coloredlogs_available,
|
is_coloredlogs_available,
|
||||||
|
@ -94,6 +94,7 @@ FSDP_MIN_VERSION = "1.12.0"
|
|||||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||||
_apex_available = _is_package_available("apex")
|
_apex_available = _is_package_available("apex")
|
||||||
_aqlm_available = _is_package_available("aqlm")
|
_aqlm_available = _is_package_available("aqlm")
|
||||||
|
_av_available = importlib.util.find_spec("av") is not None
|
||||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||||
_galore_torch_available = _is_package_available("galore_torch")
|
_galore_torch_available = _is_package_available("galore_torch")
|
||||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||||
@ -656,6 +657,10 @@ def is_aqlm_available():
|
|||||||
return _aqlm_available
|
return _aqlm_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_av_available():
    """
    Return the module-level `_av_available` flag, i.e. whether an `av` module spec was found when this
    module was imported.
    """
    return _av_available
|
||||||
|
|
||||||
|
|
||||||
def is_ninja_available():
|
def is_ninja_available():
|
||||||
r"""
|
r"""
|
||||||
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
|
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
|
||||||
@ -1012,6 +1017,16 @@ def is_mlx_available():
|
|||||||
return _mlx_available
|
return _mlx_available
|
||||||
|
|
||||||
|
|
||||||
|
# docstyle-ignore
|
||||||
|
AV_IMPORT_ERROR = """
|
||||||
|
{0} requires the PyAv library but it was not found in your environment. You can install it with:
|
||||||
|
```
|
||||||
|
pip install av
|
||||||
|
```
|
||||||
|
Please note that you may need to restart your runtime after installation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
# docstyle-ignore
|
# docstyle-ignore
|
||||||
CV2_IMPORT_ERROR = """
|
CV2_IMPORT_ERROR = """
|
||||||
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
|
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
|
||||||
@ -1336,6 +1351,7 @@ jinja2`. Please note that you may need to restart your runtime after installatio
|
|||||||
|
|
||||||
BACKENDS_MAPPING = OrderedDict(
|
BACKENDS_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
|
("av", (is_av_available, AV_IMPORT_ERROR)),
|
||||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||||
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
|
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
|
||||||
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
|
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
|
||||||
|
@ -21,7 +21,7 @@ from transformers.pipelines import VideoClassificationPipeline, pipeline
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_decord,
|
require_av,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_or_tf,
|
require_torch_or_tf,
|
||||||
@ -34,7 +34,7 @@ from .test_pipelines_common import ANY
|
|||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
@require_torch_or_tf
|
@require_torch_or_tf
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_decord
|
@require_av
|
||||||
class VideoClassificationPipelineTests(unittest.TestCase):
|
class VideoClassificationPipelineTests(unittest.TestCase):
|
||||||
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user