Use PyAV instead of Decord in examples (#21572)

* Use PyAV instead of Decord

* Get frame indices

* Fix number of frames

* Update src/transformers/models/videomae/image_processing_videomae.py

* Fix up

* Fix copies

* Update timesformer doctests

* Update docstrings
This commit is contained in:
amyeroberts 2023-03-02 12:30:38 +00:00 committed by GitHub
parent c256bc6d10
commit 3412f5979d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 224 additions and 59 deletions

View File

@ -51,7 +51,8 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc
# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install --no-cache-dir bitsandbytes
RUN python3 -m pip install --no-cache-dir decord
# For video model testing
RUN python3 -m pip install --no-cache-dir decord av==9.2.0
# For `dinat` model
RUN python3 -m pip install --no-cache-dir natten -f https://shi-labs.com/natten/wheels/$CUDA/

View File

@ -98,6 +98,7 @@ if stale_egg_info.exists():
_deps = [
"Pillow",
"accelerate>=0.10.0",
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4",
"black~=23.1",
"codecarbon==1.2.0",
@ -289,7 +290,7 @@ extras["timm"] = deps_list("timm")
extras["torch-vision"] = deps_list("torchvision") + extras["vision"]
extras["natten"] = deps_list("natten")
extras["codecarbon"] = deps_list("codecarbon")
extras["video"] = deps_list("decord")
extras["video"] = deps_list("decord", "av")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = (

View File

@ -4,6 +4,7 @@
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.10.0",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"black": "black~=23.1",
"codecarbon": "codecarbon==1.2.0",

View File

@ -1425,11 +1425,11 @@ class GitForCausalLM(GitPreTrainedModel):
Video captioning example:
```python
>>> from transformers import AutoProcessor, AutoModelForCausalLM
>>> from PIL import Image
>>> import av
>>> import numpy as np
>>> from PIL import Image
>>> from huggingface_hub import hf_hub_download
>>> from decord import VideoReader, cpu
>>> from transformers import AutoProcessor, AutoModelForCausalLM
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
@ -1438,6 +1438,27 @@ class GitForCausalLM(GitPreTrainedModel):
>>> np.random.seed(45)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
@ -1447,24 +1468,20 @@ class GitForCausalLM(GitPreTrainedModel):
... return indices
>>> def sample_frames(file_path, num_frames):
... videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
... videoreader.seek(0)
... indices = sample_frame_indices(clip_len=num_frames, frame_sample_rate=4, seg_len=len(videoreader))
... frames = videoreader.get_batch(indices).asnumpy()
... return list(frames)
>>> # load video
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> container = av.open(file_path)
>>> # sample frames
>>> num_frames = model.config.num_image_with_embedding
>>> frames = sample_frames(file_path, num_frames)
>>> indices = sample_frame_indices(
... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
... )
>>> frames = read_video_pyav(container, indices)
>>> pixel_values = processor(images=frames, return_tensors="pt").pixel_values
>>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
>>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)

View File

@ -570,12 +570,35 @@ class TimesformerModel(TimesformerPreTrainedModel):
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import av
>>> import numpy as np
>>> from transformers import AutoFeatureExtractor, TimesformerModel
>>> from transformers import AutoImageProcessor, TimesformerModel
>>> from huggingface_hub import hf_hub_download
>>> np.random.seed(0)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
@ -590,24 +613,23 @@ class TimesformerModel(TimesformerPreTrainedModel):
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> container = av.open(file_path)
>>> # sample 8 frames
>>> videoreader.seek(0)
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container, indices)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
>>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
>>> model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400")
>>> # prepare video for the model
>>> inputs = feature_extractor(list(video), return_tensors="pt")
>>> inputs = image_processor(list(video), return_tensors="pt")
>>> # forward pass
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 1568, 768]
[1, 1569, 768]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -676,16 +698,37 @@ class TimesformerForVideoClassification(TimesformerPreTrainedModel):
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import av
>>> import torch
>>> import numpy as np
>>> from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
>>> from transformers import AutoImageProcessor, TimesformerForVideoClassification
>>> from huggingface_hub import hf_hub_download
>>> np.random.seed(0)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
@ -699,17 +742,16 @@ class TimesformerForVideoClassification(TimesformerPreTrainedModel):
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> container = av.open(file_path)
>>> # sample 8 frames
>>> videoreader.seek(0)
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container, indices)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k400")
>>> inputs = feature_extractor(list(video), return_tensors="pt")
>>> inputs = image_processor(list(video), return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)

View File

@ -576,12 +576,35 @@ class VideoMAEModel(VideoMAEPreTrainedModel):
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import av
>>> import numpy as np
>>> from transformers import AutoImageProcessor, VideoMAEModel
>>> from huggingface_hub import hf_hub_download
>>> np.random.seed(0)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
@ -596,12 +619,11 @@ class VideoMAEModel(VideoMAEPreTrainedModel):
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> container = av.open(file_path)
>>> # sample 16 frames
>>> videoreader.seek(0)
>>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container, indices)
>>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
>>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")
@ -944,7 +966,7 @@ class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import av
>>> import torch
>>> import numpy as np
@ -954,6 +976,27 @@ class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
>>> np.random.seed(0)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
@ -967,12 +1010,11 @@ class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> container = av.open(file_path)
>>> # sample 16 frames
>>> videoreader.seek(0)
>>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container, indices)
>>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")

View File

@ -1065,7 +1065,7 @@ class XCLIPVisionModel(XCLIPPreTrainedModel):
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import av
>>> import torch
>>> import numpy as np
@ -1075,6 +1075,27 @@ class XCLIPVisionModel(XCLIPPreTrainedModel):
>>> np.random.seed(0)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
@ -1088,12 +1109,11 @@ class XCLIPVisionModel(XCLIPPreTrainedModel):
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> container = av.open(file_path)
>>> # sample 8 frames
>>> vr.seek(0)
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=len(vr))
>>> video = vr.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container, indices)
>>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
>>> model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
@ -1363,7 +1383,7 @@ class XCLIPModel(XCLIPPreTrainedModel):
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import av
>>> import torch
>>> import numpy as np
@ -1373,6 +1393,27 @@ class XCLIPModel(XCLIPPreTrainedModel):
>>> np.random.seed(0)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
@ -1386,12 +1427,11 @@ class XCLIPModel(XCLIPPreTrainedModel):
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> container = av.open(file_path)
>>> # sample 16 frames
>>> vr.seek(0)
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=len(vr))
>>> video = vr.get_batch(indices).asnumpy()
>>> # sample 8 frames
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container, indices)
>>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
>>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
@ -1451,7 +1491,7 @@ class XCLIPModel(XCLIPPreTrainedModel):
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import av
>>> import torch
>>> import numpy as np
@ -1461,6 +1501,27 @@ class XCLIPModel(XCLIPPreTrainedModel):
>>> np.random.seed(0)
>>> def read_video_pyav(container, indices):
... '''
... Decode the video with PyAV decoder.
... Args:
... container (`av.container.input.InputContainer`): PyAV container.
... indices (`List[int]`): List of frame indices to decode.
... Returns:
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
... '''
... frames = []
... container.seek(0)
... start_index = indices[0]
... end_index = indices[-1]
... for i, frame in enumerate(container.decode(video=0)):
... if i > end_index:
... break
... if i >= start_index and i in indices:
... frames.append(frame)
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
@ -1474,12 +1535,11 @@ class XCLIPModel(XCLIPPreTrainedModel):
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> container = av.open(file_path)
>>> # sample 16 frames
>>> vr.seek(0)
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=len(vr))
>>> video = vr.get_batch(indices).asnumpy()
>>> # sample 8 frames
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container, indices)
>>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
>>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")

View File

@ -184,6 +184,7 @@ src/transformers/models/swinv2/configuration_swinv2.py
src/transformers/models/table_transformer/modeling_table_transformer.py
src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
src/transformers/models/timesformer/modeling_timesformer.py
src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
src/transformers/models/transfo_xl/configuration_transfo_xl.py
src/transformers/models/trocr/configuration_trocr.py