Add docstrings and fix VIVIT examples (#25628)

* fix docstrings and examples

* docstring update

* add missing whitespace
This commit is contained in:
Tigran Khachatryan 2023-08-26 22:08:47 +03:00 committed by GitHub
parent 960807f62e
commit 686c68f64c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 108 additions and 6 deletions

View File

@ -200,6 +200,17 @@ def prepare_video():
np.random.seed(0)
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
"""
Sample a given number of frame indices from the video.
Args:
clip_len (`int`): Total number of frames to sample.
frame_sample_rate (`int`): Sample every n-th frame.
seg_len (`int`): Maximum allowed index of sample's last frame.
Returns:
indices (`List[int]`): List of sampled frame indices
"""
converted_len = int(clip_len * frame_sample_rate)
end_idx = np.random.randint(converted_len, seg_len)
start_idx = end_idx - converted_len

View File

@ -1465,6 +1465,15 @@ class GitForCausalLM(GitPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len

View File

@ -601,6 +601,15 @@ class TimesformerModel(TimesformerPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
@ -730,6 +739,15 @@ class TimesformerForVideoClassification(TimesformerPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len

View File

@ -612,6 +612,15 @@ class VideoMAEModel(VideoMAEPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
@ -1008,6 +1017,15 @@ class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len

View File

@ -532,6 +532,15 @@ class VivitModel(VivitPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
@ -547,8 +556,8 @@ class VivitModel(VivitPreTrainedModel):
>>> container = av.open(file_path)
>>> # sample 32 frames
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container=container, indices=indices)
>>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
>>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
@ -639,8 +648,9 @@ class VivitForVideoClassification(VivitPreTrainedModel):
```python
>>> import av
>>> import numpy as np
>>> import torch
>>> from transformers import VivitImageProcessor, VivitModel
>>> from transformers import VivitImageProcessor, VivitForVideoClassification
>>> from huggingface_hub import hf_hub_download
>>> np.random.seed(0)
@ -668,6 +678,15 @@ class VivitForVideoClassification(VivitPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
@ -683,8 +702,8 @@ class VivitForVideoClassification(VivitPreTrainedModel):
>>> container = av.open(file_path)
>>> # sample 32 frames
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container=container, indices=indices)
>>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
>>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
@ -698,7 +717,7 @@ class VivitForVideoClassification(VivitPreTrainedModel):
>>> # model predicts one of the 400 Kinetics-400 classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
eating spaghetti
LABEL_116
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

View File

@ -1105,6 +1105,15 @@ class XCLIPVisionModel(XCLIPPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
@ -1423,6 +1432,15 @@ class XCLIPModel(XCLIPPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
@ -1531,6 +1549,15 @@ class XCLIPModel(XCLIPPreTrainedModel):
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len