From 3412f5979d08a2db4a7575c463435405843162a7 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 2 Mar 2023 12:30:38 +0000
Subject: [PATCH] Use PyAV instead of Decord in examples (#21572)

* Use PyAV instead of Decord

* Get frame indices

* Fix number of frames

* Update src/transformers/models/videomae/image_processing_videomae.py

* Fix up

* Fix copies

* Update timesformer doctests

* Update docstrings
---
 docker/transformers-all-latest-gpu/Dockerfile |  3 +-
 setup.py                                      |  3 +-
 src/transformers/dependency_versions_table.py |  1 +
 src/transformers/models/git/modeling_git.py   | 43 ++++++---
 .../timesformer/modeling_timesformer.py       | 76 +++++++++++----
 .../models/videomae/modeling_videomae.py      | 62 ++++++++++--
 .../models/x_clip/modeling_x_clip.py          | 94 +++++++++++++++----
 utils/documentation_tests.txt                 |  1 +
 8 files changed, 224 insertions(+), 59 deletions(-)

diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index 95e127ef6dd..46d8b127b8e 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -51,7 +51,8 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc
 # Add bitsandbytes for mixed int8 testing
 RUN python3 -m pip install --no-cache-dir bitsandbytes

-RUN python3 -m pip install --no-cache-dir decord
+# For video model testing
+RUN python3 -m pip install --no-cache-dir decord av==9.2.0

 # For `dinat` model
 RUN python3 -m pip install --no-cache-dir natten -f https://shi-labs.com/natten/wheels/$CUDA/
diff --git a/setup.py b/setup.py
index 17cc262a2c4..1be88908c93 100644
--- a/setup.py
+++ b/setup.py
@@ -98,6 +98,7 @@ if stale_egg_info.exists():
 _deps = [
     "Pillow",
     "accelerate>=0.10.0",
+    "av==9.2.0",  # Latest version of PyAV (10.0.0) has issues with audio streams.
"beautifulsoup4", "black~=23.1", "codecarbon==1.2.0", @@ -289,7 +290,7 @@ extras["timm"] = deps_list("timm") extras["torch-vision"] = deps_list("torchvision") + extras["vision"] extras["natten"] = deps_list("natten") extras["codecarbon"] = deps_list("codecarbon") -extras["video"] = deps_list("decord") +extras["video"] = deps_list("decord", "av") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["testing"] = ( diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index b7dfc633768..79f9118ae84 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -4,6 +4,7 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.10.0", + "av": "av==9.2.0", "beautifulsoup4": "beautifulsoup4", "black": "black~=23.1", "codecarbon": "codecarbon==1.2.0", diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 934077f871f..f55a687c4bc 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1425,11 +1425,11 @@ class GitForCausalLM(GitPreTrainedModel): Video captioning example: ```python - >>> from transformers import AutoProcessor, AutoModelForCausalLM - >>> from PIL import Image + >>> import av >>> import numpy as np + >>> from PIL import Image >>> from huggingface_hub import hf_hub_download - >>> from decord import VideoReader, cpu + >>> from transformers import AutoProcessor, AutoModelForCausalLM >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex") >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex") @@ -1438,6 +1438,27 @@ class GitForCausalLM(GitPreTrainedModel): >>> np.random.seed(45) + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`List[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) @@ -1447,24 +1468,20 @@ class GitForCausalLM(GitPreTrainedModel): ... return indices - >>> def sample_frames(file_path, num_frames): - ... videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0)) - ... videoreader.seek(0) - ... indices = sample_frame_indices(clip_len=num_frames, frame_sample_rate=4, seg_len=len(videoreader)) - ... frames = videoreader.get_batch(indices).asnumpy() - ... return list(frames) - - >>> # load video >>> file_path = hf_hub_download( ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset" ... ) + >>> container = av.open(file_path) >>> # sample frames >>> num_frames = model.config.num_image_with_embedding - >>> frames = sample_frames(file_path, num_frames) + >>> indices = sample_frame_indices( + ... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames + ... 
+        ... )
+        >>> frames = read_video_pyav(container, indices)

-        >>> pixel_values = processor(images=frames, return_tensors="pt").pixel_values
+        >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values

         >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py
index b5a525127b3..9f886b6ece5 100644
--- a/src/transformers/models/timesformer/modeling_timesformer.py
+++ b/src/transformers/models/timesformer/modeling_timesformer.py
@@ -570,12 +570,35 @@ class TimesformerModel(TimesformerPreTrainedModel):
         Examples:

         ```python
-        >>> from decord import VideoReader, cpu
+        >>> import av
         >>> import numpy as np

-        >>> from transformers import AutoFeatureExtractor, TimesformerModel
+        >>> from transformers import AutoImageProcessor, TimesformerModel
         >>> from huggingface_hub import hf_hub_download

+        >>> np.random.seed(0)
+
+
+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+

         >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
         ...     converted_len = int(clip_len * frame_sample_rate)
@@ -590,24 +613,23 @@ class TimesformerModel(TimesformerPreTrainedModel):
         >>> file_path = hf_hub_download(
         ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
         ... )
-        >>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+        >>> container = av.open(file_path)

         >>> # sample 8 frames
-        >>> videoreader.seek(0)
-        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=len(videoreader))
-        >>> video = videoreader.get_batch(indices).asnumpy()
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)

-        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
+        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
         >>> model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400")

         >>> # prepare video for the model
-        >>> inputs = feature_extractor(list(video), return_tensors="pt")
+        >>> inputs = image_processor(list(video), return_tensors="pt")

         >>> # forward pass
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         >>> list(last_hidden_states.shape)
-        [1, 1568, 768]
+        [1, 1569, 768]
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -676,16 +698,37 @@ class TimesformerForVideoClassification(TimesformerPreTrainedModel):
         Examples:

         ```python
-        >>> from decord import VideoReader, cpu
+        >>> import av
         >>> import torch
         >>> import numpy as np

-        >>> from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
+        >>> from transformers import AutoImageProcessor, TimesformerForVideoClassification
         >>> from huggingface_hub import hf_hub_download

         >>> np.random.seed(0)


+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
         >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
         ...     converted_len = int(clip_len * frame_sample_rate)
         ...     end_idx = np.random.randint(converted_len, seg_len)
@@ -699,17 +742,16 @@ class TimesformerForVideoClassification(TimesformerPreTrainedModel):
         >>> file_path = hf_hub_download(
         ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
         ... )
-        >>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+        >>> container = av.open(file_path)

         >>> # sample 8 frames
-        >>> videoreader.seek(0)
-        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=len(videoreader))
-        >>> video = videoreader.get_batch(indices).asnumpy()
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)

-        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
+        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
         >>> model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k400")

-        >>> inputs = feature_extractor(list(video), return_tensors="pt")
+        >>> inputs = image_processor(list(video), return_tensors="pt")

         >>> with torch.no_grad():
         ...     outputs = model(**inputs)
diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py
index ee166317909..0d1f12b8730 100644
--- a/src/transformers/models/videomae/modeling_videomae.py
+++ b/src/transformers/models/videomae/modeling_videomae.py
@@ -576,12 +576,35 @@ class VideoMAEModel(VideoMAEPreTrainedModel):
         Examples:

         ```python
-        >>> from decord import VideoReader, cpu
+        >>> import av
         >>> import numpy as np

         >>> from transformers import AutoImageProcessor, VideoMAEModel
         >>> from huggingface_hub import hf_hub_download

+        >>> np.random.seed(0)
+
+
+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+

         >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
         ...     converted_len = int(clip_len * frame_sample_rate)
@@ -596,12 +619,11 @@ class VideoMAEModel(VideoMAEPreTrainedModel):
         >>> file_path = hf_hub_download(
         ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
         ... )
-        >>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+        >>> container = av.open(file_path)

         >>> # sample 16 frames
-        >>> videoreader.seek(0)
-        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(videoreader))
-        >>> video = videoreader.get_batch(indices).asnumpy()
+        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)

         >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
         >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")
@@ -944,7 +966,7 @@ class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
         Examples:

         ```python
-        >>> from decord import VideoReader, cpu
+        >>> import av
         >>> import torch
         >>> import numpy as np
@@ -954,6 +976,27 @@ class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):

         >>> np.random.seed(0)

+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
         >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
         ...     converted_len = int(clip_len * frame_sample_rate)
         ...     end_idx = np.random.randint(converted_len, seg_len)
@@ -967,12 +1010,11 @@ class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
         >>> file_path = hf_hub_download(
         ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
         ... )
-        >>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+        >>> container = av.open(file_path)

         >>> # sample 16 frames
-        >>> videoreader.seek(0)
-        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(videoreader))
-        >>> video = videoreader.get_batch(indices).asnumpy()
+        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)

         >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
         >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py
index 69289ced40a..337d919cf4a 100644
--- a/src/transformers/models/x_clip/modeling_x_clip.py
+++ b/src/transformers/models/x_clip/modeling_x_clip.py
@@ -1065,7 +1065,7 @@ class XCLIPVisionModel(XCLIPPreTrainedModel):
         Examples:

         ```python
-        >>> from decord import VideoReader, cpu
+        >>> import av
         >>> import torch
         >>> import numpy as np
@@ -1075,6 +1075,27 @@ class XCLIPVisionModel(XCLIPPreTrainedModel):

         >>> np.random.seed(0)

+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
         >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
         ...     converted_len = int(clip_len * frame_sample_rate)
         ...     end_idx = np.random.randint(converted_len, seg_len)
@@ -1088,12 +1109,11 @@ class XCLIPVisionModel(XCLIPPreTrainedModel):
         >>> file_path = hf_hub_download(
         ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
         ... )
-        >>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+        >>> container = av.open(file_path)

-        >>> # sample 16 frames
-        >>> vr.seek(0)
-        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=len(vr))
-        >>> video = vr.get_batch(indices).asnumpy()
+        >>> # sample 8 frames
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)

         >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
         >>> model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
@@ -1363,7 +1383,7 @@ class XCLIPModel(XCLIPPreTrainedModel):
         Examples:

         ```python
-        >>> from decord import VideoReader, cpu
+        >>> import av
         >>> import torch
         >>> import numpy as np
@@ -1373,6 +1393,27 @@ class XCLIPModel(XCLIPPreTrainedModel):

         >>> np.random.seed(0)

+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
         >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
         ...     converted_len = int(clip_len * frame_sample_rate)
         ...     end_idx = np.random.randint(converted_len, seg_len)
@@ -1386,12 +1427,11 @@ class XCLIPModel(XCLIPPreTrainedModel):
         >>> file_path = hf_hub_download(
         ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
         ... )
-        >>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+        >>> container = av.open(file_path)

-        >>> # sample 16 frames
-        >>> vr.seek(0)
-        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=len(vr))
-        >>> video = vr.get_batch(indices).asnumpy()
+        >>> # sample 8 frames
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)

         >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
         >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
@@ -1451,7 +1491,7 @@ class XCLIPModel(XCLIPPreTrainedModel):
         Examples:

         ```python
-        >>> from decord import VideoReader, cpu
+        >>> import av
         >>> import torch
         >>> import numpy as np
@@ -1461,6 +1501,27 @@ class XCLIPModel(XCLIPPreTrainedModel):

         >>> np.random.seed(0)

+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
         >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
         ...     converted_len = int(clip_len * frame_sample_rate)
         ...     end_idx = np.random.randint(converted_len, seg_len)
@@ -1474,12 +1535,11 @@ class XCLIPModel(XCLIPPreTrainedModel):
         >>> file_path = hf_hub_download(
         ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
         ... )
-        >>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+        >>> container = av.open(file_path)

-        >>> # sample 16 frames
-        >>> vr.seek(0)
-        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=len(vr))
-        >>> video = vr.get_batch(indices).asnumpy()
+        >>> # sample 8 frames
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)

         >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
         >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 210799b93a4..8b622bf778d 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -184,6 +184,7 @@ src/transformers/models/swinv2/configuration_swinv2.py
 src/transformers/models/table_transformer/modeling_table_transformer.py
 src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
 src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
+src/transformers/models/timesformer/modeling_timesformer.py
 src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
 src/transformers/models/transfo_xl/configuration_transfo_xl.py
 src/transformers/models/trocr/configuration_trocr.py
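
---

Note for reviewers: the hunks above only exercise `read_video_pyav` and `sample_frame_indices` inside docstring doctests. Below is a minimal standalone sketch that strings the same two helpers together as a plain script, assuming `av==9.2.0`, `numpy`, and `huggingface_hub` are installed. The body of `read_video_pyav` is copied verbatim from the examples in this patch; the middle of `sample_frame_indices` falls outside the diff context above, so its body here is reconstructed to match the standard form of this helper in the transformers docs and may differ cosmetically from the actual file contents.

```python
import av
import numpy as np
from huggingface_hub import hf_hub_download


def read_video_pyav(container, indices):
    """Decode the frames at `indices` from the first video stream.

    Returns an array of shape (num_frames, height, width, 3)."""
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Sample `clip_len` evenly spaced frame indices from a randomly placed window."""
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


# Same demo clip used by the doctests in this patch.
file_path = hf_hub_download(
    repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
)
container = av.open(file_path)

# `container.streams.video[0].frames` replaces Decord's `len(videoreader)`
# as the total frame count of the clip.
indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
video = read_video_pyav(container, indices)
print(video.shape)  # (8, height, width, 3)
```

One behavioural difference worth noting when reviewing: Decord's `get_batch` fetches exactly the requested indices, while the helper above iterates the stream from the start and keeps frames whose position is in `indices`, so it decodes (but discards) every frame up to the last requested index.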