mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[pipeline
] Add pool option to image feature extraction pipeline (#28985)
* Add pool option * PR comments - error message and exact outputs check
This commit is contained in:
parent
c47576ca6e
commit
e770f0316d
@ -14,6 +14,8 @@ if is_vision_available():
|
||||
image_processor_kwargs (`dict`, *optional*):
|
||||
Additional dictionary of keyword arguments passed along to the image processor e.g.
|
||||
{"size": {"height": 100, "width": 100}}
|
||||
pool (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the pooled output. If `False`, the model will return the raw hidden states.
|
||||
""",
|
||||
)
|
||||
class ImageFeatureExtractionPipeline(Pipeline):
|
||||
@ -41,9 +43,14 @@ class ImageFeatureExtractionPipeline(Pipeline):
|
||||
[huggingface.co/models](https://huggingface.co/models).
|
||||
"""
|
||||
|
||||
def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, **kwargs):
|
||||
def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, pool=None, **kwargs):
|
||||
preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
|
||||
postprocess_params = {"return_tensors": return_tensors} if return_tensors is not None else {}
|
||||
|
||||
postprocess_params = {}
|
||||
if pool is not None:
|
||||
postprocess_params["pool"] = pool
|
||||
if return_tensors is not None:
|
||||
postprocess_params["return_tensors"] = return_tensors
|
||||
|
||||
if "timeout" in kwargs:
|
||||
preprocess_params["timeout"] = kwargs["timeout"]
|
||||
@ -59,14 +66,25 @@ class ImageFeatureExtractionPipeline(Pipeline):
|
||||
model_outputs = self.model(**model_inputs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs, return_tensors=False):
|
||||
# [0] is the first available tensor, logits or last_hidden_state.
|
||||
def postprocess(self, model_outputs, pool=None, return_tensors=False):
|
||||
pool = pool if pool is not None else False
|
||||
|
||||
if pool:
|
||||
if "pooler_output" not in model_outputs:
|
||||
raise ValueError(
|
||||
"No pooled output was returned. Make sure the model has a `pooler` layer when using the `pool` option."
|
||||
)
|
||||
outputs = model_outputs["pooler_output"]
|
||||
else:
|
||||
# [0] is the first available tensor, logits or last_hidden_state.
|
||||
outputs = model_outputs[0]
|
||||
|
||||
if return_tensors:
|
||||
return model_outputs[0]
|
||||
return outputs
|
||||
if self.framework == "pt":
|
||||
return model_outputs[0].tolist()
|
||||
return outputs.tolist()
|
||||
elif self.framework == "tf":
|
||||
return model_outputs[0].numpy().tolist()
|
||||
return outputs.numpy().tolist()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -62,10 +62,21 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||
nested_simplify(outputs[0][0]),
|
||||
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
|
||||
|
||||
@require_torch
|
||||
def test_small_model_w_pooler_pt(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler", framework="pt"
|
||||
)
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img, pool=True)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs[0]),
|
||||
[-0.056, 0.083, 0.021, 0.038, 0.242, -0.279, -0.033, -0.003, 0.200, -0.192, 0.045, -0.095, -0.077, 0.017, -0.058, -0.063, -0.029, -0.204, 0.014, 0.042, 0.305, -0.205, -0.099, 0.146, -0.287, 0.020, 0.168, -0.052, 0.046, 0.048, -0.156, 0.093]) # fmt: skip
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler", framework="tf"
|
||||
)
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img)
|
||||
@ -73,6 +84,17 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||
nested_simplify(outputs[0][0]),
|
||||
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
|
||||
|
||||
@require_tf
|
||||
def test_small_model_w_pooler_tf(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler", framework="tf"
|
||||
)
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img, pool=True)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs[0]),
|
||||
[-0.056, 0.083, 0.021, 0.038, 0.242, -0.279, -0.033, -0.003, 0.200, -0.192, 0.045, -0.095, -0.077, 0.017, -0.058, -0.063, -0.029, -0.204, 0.014, 0.042, 0.305, -0.205, -0.099, 0.146, -0.287, 0.020, 0.168, -0.052, 0.046, 0.048, -0.156, 0.093]) # fmt: skip
|
||||
|
||||
@require_torch
|
||||
def test_image_processing_small_model_pt(self):
|
||||
feature_extractor = pipeline(
|
||||
@ -91,6 +113,10 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
|
||||
|
||||
# Test pooling option
|
||||
outputs = feature_extractor(img, pool=True)
|
||||
self.assertEqual(np.squeeze(outputs).shape, (32,))
|
||||
|
||||
@require_tf
|
||||
def test_image_processing_small_model_tf(self):
|
||||
feature_extractor = pipeline(
|
||||
@ -109,6 +135,10 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
|
||||
|
||||
# Test pooling option
|
||||
outputs = feature_extractor(img, pool=True)
|
||||
self.assertEqual(np.squeeze(outputs).shape, (32,))
|
||||
|
||||
@require_torch
|
||||
def test_return_tensors_pt(self):
|
||||
feature_extractor = pipeline(
|
||||
|
Loading…
Reference in New Issue
Block a user