mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

parent e27d230ddd
commit e40f301f1f
@@ -180,13 +180,15 @@ class SmolVLMProcessor(ProcessorMixin):
 
         super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
 
-    def process_vision(self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None):
+    def process_vision(
+        self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None, processor=None
+    ):
         if text is not None:
             n_images_in_text = [sample.count(self.image_token) for sample in text]
 
         n_images_in_images = [len(sublist) for sublist in images]
-        image_inputs = self.image_processor(
-            images, do_image_splitting=do_image_splitting, size=image_processor_size, **output_kwargs["images_kwargs"]
+        image_inputs = processor(
+            images, do_image_splitting=do_image_splitting, size=image_processor_size, **output_kwargs
         )
 
         if text is None:
@@ -309,9 +311,10 @@ class SmolVLMProcessor(ProcessorMixin):
             text, vision_inputs = self.process_vision(
                 text,
                 images,
-                output_kwargs,
+                output_kwargs["images_kwargs"],
                 do_image_splitting=self.do_image_splitting,
                 image_processor_size=self.image_size,
+                processor=self.image_processor,
             )
             inputs.update(vision_inputs)
         elif videos is not None:
@@ -319,9 +322,10 @@ class SmolVLMProcessor(ProcessorMixin):
             text, vision_inputs = self.process_vision(
                 text,
                 videos,
-                output_kwargs,
+                output_kwargs["videos_kwargs"],
                 do_image_splitting=self.do_image_splitting,
                 image_processor_size=self.video_size,
+                processor=self.video_processor,
             )
             inputs.update(vision_inputs)
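
Taken together, these three hunks change process_vision from hard-coding self.image_processor and output_kwargs["images_kwargs"] to receiving the concrete processor and that modality's kwargs sub-dict from the call site; before this commit, the video branch ran its frames through the image processor with image kwargs. A minimal, self-contained sketch of the resulting dispatch pattern (names are illustrative, not the transformers source):

    class ToyVisionProcessor:
        """Toy stand-in for the dispatch pattern this commit introduces."""

        def __init__(self, image_processor, video_processor):
            # Plain callables here; in transformers these are the image/video
            # processor objects held by SmolVLMProcessor.
            self.image_processor = image_processor
            self.video_processor = video_processor

        def process_vision(self, batch, modality_kwargs, processor=None):
            # Modality-agnostic helper: the caller picks the processor and the
            # kwargs sub-dict, so nothing here assumes images.
            return processor(batch, **modality_kwargs)

        def __call__(self, images=None, videos=None, output_kwargs=None):
            output_kwargs = output_kwargs or {"images_kwargs": {}, "videos_kwargs": {}}
            if images is not None:
                return self.process_vision(
                    images, output_kwargs["images_kwargs"], processor=self.image_processor
                )
            if videos is not None:
                return self.process_vision(
                    videos, output_kwargs["videos_kwargs"], processor=self.video_processor
                )
            return {}

With the processor and kwargs dict passed explicitly, the video branch can no longer silently fall through to the image processor, which is exactly what the diff fixes.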
@@ -22,7 +22,7 @@ import requests
 
 from transformers import SmolVLMProcessor
 from transformers.models.auto.processing_auto import AutoProcessor
-from transformers.testing_utils import is_flaky, require_av, require_torch, require_vision
+from transformers.testing_utils import require_av, require_torch, require_vision
 from transformers.utils import is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
@@ -118,10 +118,6 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def tearDownClass(cls):
         shutil.rmtree(cls.tmpdirname, ignore_errors=True)
 
-    @is_flaky  # fails 15 out of 100, FIXME @raushan
-    def test_structured_kwargs_nested_from_dict_video(self):
-        super().test_structured_kwargs_nested_from_dict_video()
-
     def test_process_interleaved_images_prompts_no_image_splitting(self):
         processor_components = self.prepare_components()
         processor_components["tokenizer"] = self.get_component("tokenizer", padding_side="left")
@@ -467,6 +463,31 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertEqual(inputs["pixel_values"].shape[3], 300)
         self.assertEqual(len(inputs["input_ids"][0]), 76)
 
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs(batch_size=2, modality="video")
+        video_input = self.prepare_video_inputs(batch_size=2)
+        inputs = processor(
+            text=input_str,
+            videos=video_input,
+            return_tensors="pt",
+            do_rescale=True,
+            rescale_factor=-1,
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
     @require_torch
     @require_vision
     def test_text_only_inference(self):
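
The new test_unstructured_kwargs_batched_video passes flat ("unstructured") kwargs straight into the processor call and relies on the processor to route them: rescale_factor=-1 flips every pixel value non-positive (hence the mean <= 0 assertion on the video tensor), and padding="max_length" with max_length=76 fixes the token length at exactly 76. A hedged sketch of the same call outside the test harness; the checkpoint name, prompt format, and video layout are assumptions, not taken from this commit:

    import numpy as np
    from transformers import AutoProcessor

    # Checkpoint name is an assumption for illustration.
    processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")

    # Two dummy clips of eight 224x224 RGB frames each (layout assumed).
    videos = [
        [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(8)]
        for _ in range(2)
    ]
    # Prompt format assumed; real prompts normally come from the chat template.
    texts = ["User: <video> What happens in this clip?\nAssistant:"] * 2

    inputs = processor(
        text=texts,
        videos=videos,
        return_tensors="pt",
        do_rescale=True,
        rescale_factor=-1,     # routed to the video processor
        padding="max_length",  # routed to the tokenizer
        max_length=76,
    )
    assert inputs["pixel_values"].mean() <= 0  # negative rescale flips the sign
    assert inputs["input_ids"].shape[1] == 76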