mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix video batching to videollava (#32139)
--------- Co-authored-by: Merve Noyan <mervenoyan@Merve-MacBook-Pro.local>
This commit is contained in:
parent
a5b226ce98
commit
9ced33ca7f
@ -55,8 +55,11 @@ def make_batched_videos(videos) -> List[VideoInput]:
|
||||
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
||||
return videos
|
||||
|
||||
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]) and len(videos[0].shape) == 4:
|
||||
return [list(video) for video in videos]
|
||||
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
||||
if isinstance(videos[0], PIL.Image.Image):
|
||||
return [videos]
|
||||
elif len(videos[0].shape) == 4:
|
||||
return [list(video) for video in videos]
|
||||
|
||||
elif is_valid_image(videos) and len(videos.shape) == 4:
|
||||
return [list(videos)]
|
||||
|
@ -97,8 +97,7 @@ class VideoLlavaImageProcessingTester(unittest.TestCase):
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, torchify=False):
|
||||
numpify = not torchify
|
||||
def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
@ -108,15 +107,19 @@ class VideoLlavaImageProcessingTester(unittest.TestCase):
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# let's simply copy the frames to fake a long video-clip
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(8, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(8, 1, 1, 1)
|
||||
videos.append(video)
|
||||
if numpify or torchify:
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(8, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(8, 1, 1, 1)
|
||||
videos.append(video)
|
||||
else:
|
||||
videos = []
|
||||
for pil_image in images:
|
||||
videos.append([pil_image] * 8)
|
||||
|
||||
return videos
|
||||
|
||||
@ -197,7 +200,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(numpify=True, equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
@ -211,6 +214,24 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pil_videos(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# the inputs come in list of lists batched format
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
|
Loading…
Reference in New Issue
Block a user