Add timeout parameter to load_image function (#25184)

* Add timeout parameter to load_image function.

* Remove line.

* Reformat code

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Add parameter to docs.

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Roland Szabo 2023-08-03 14:51:54 +00:00 committed by GitHub
parent 6d3f9c1e2e
commit d114a6b71f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 102 additions and 29 deletions

View File

@ -14,7 +14,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import requests
@ -253,13 +253,15 @@ def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List,
return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
"""
Loads `image` to a PIL Image.
Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
timeout (`float`, *optional*):
The timeout value in seconds for the URL request.
Returns:
`PIL.Image.Image`: A PIL Image.
@ -269,7 +271,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = PIL.Image.open(requests.get(image, stream=True).raw)
image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:

View File

@ -68,6 +68,9 @@ class DepthEstimationPipeline(Pipeline):
top_k (`int`, *optional*, defaults to 5):
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@ -81,11 +84,14 @@ class DepthEstimationPipeline(Pipeline):
"""
return super().__call__(images, **kwargs)
def _sanitize_parameters(self, **kwargs):
return {}, {}, {}
def _sanitize_parameters(self, timeout=None, **kwargs):
preprocess_params = {}
if timeout is not None:
preprocess_params["timeout"] = timeout
return preprocess_params, {}, {}
def preprocess(self, image):
image = load_image(image)
def preprocess(self, image, timeout=None):
image = load_image(image, timeout)
self.image_size = image.size
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
return model_inputs

View File

@ -159,6 +159,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
max_seq_len=None,
top_k=None,
handle_impossible_answer=None,
timeout=None,
**kwargs,
):
preprocess_params, postprocess_params = {}, {}
@ -174,6 +175,8 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
preprocess_params["lang"] = lang
if tesseract_config is not None:
preprocess_params["tesseract_config"] = tesseract_config
if timeout is not None:
preprocess_params["timeout"] = timeout
if top_k is not None:
if top_k < 1:
@ -244,6 +247,9 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
Language to use while running OCR. Defaults to english.
tesseract_config (`str`, *optional*):
Additional flags to pass to tesseract while running OCR.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
@ -273,6 +279,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
word_boxes: Tuple[str, List[float]] = None,
lang=None,
tesseract_config="",
timeout=None,
):
# NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
# to support documents with enough tokens that overflow the model's window
@ -285,7 +292,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
image = None
image_features = {}
if input.get("image", None) is not None:
image = load_image(input["image"])
image = load_image(input["image"], timeout=timeout)
if self.image_processor is not None:
image_features.update(self.image_processor(images=image, return_tensors=self.framework))
elif self.feature_extractor is not None:

View File

@ -62,11 +62,14 @@ class ImageClassificationPipeline(Pipeline):
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
def _sanitize_parameters(self, top_k=None):
def _sanitize_parameters(self, top_k=None, timeout=None):
preprocess_params = {}
if timeout is not None:
preprocess_params["timeout"] = timeout
postprocess_params = {}
if top_k is not None:
postprocess_params["top_k"] = top_k
return {}, {}, postprocess_params
return preprocess_params, {}, postprocess_params
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
@ -86,6 +89,9 @@ class ImageClassificationPipeline(Pipeline):
top_k (`int`, *optional*, defaults to 5):
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@ -99,8 +105,8 @@ class ImageClassificationPipeline(Pipeline):
"""
return super().__call__(images, **kwargs)
def preprocess(self, image):
image = load_image(image)
def preprocess(self, image, timeout=None):
image = load_image(image, timeout=timeout)
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
return model_inputs

View File

@ -89,6 +89,8 @@ class ImageSegmentationPipeline(Pipeline):
postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
if "overlap_mask_area_threshold" in kwargs:
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
if "timeout" in kwargs:
preprocess_kwargs["timeout"] = kwargs["timeout"]
return preprocess_kwargs, {}, postprocess_kwargs
@ -116,6 +118,9 @@ class ImageSegmentationPipeline(Pipeline):
Threshold to use when turning the predicted masks into binary values.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
Mask overlap threshold to eliminate small, disconnected segments.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
@ -133,8 +138,8 @@ class ImageSegmentationPipeline(Pipeline):
"""
return super().__call__(images, **kwargs)
def preprocess(self, image, subtask=None):
image = load_image(image)
def preprocess(self, image, subtask=None, timeout=None):
image = load_image(image, timeout=timeout)
target_size = [(image.height, image.width)]
if self.model.config.__class__.__name__ == "OneFormerConfig":
if subtask is None:

View File

@ -58,12 +58,14 @@ class ImageToTextPipeline(Pipeline):
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
)
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None):
forward_kwargs = {}
preprocess_params = {}
if prompt is not None:
preprocess_params["prompt"] = prompt
if timeout is not None:
preprocess_params["timeout"] = timeout
if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
@ -97,6 +99,9 @@ class ImageToTextPipeline(Pipeline):
generate_kwargs (`Dict`, *optional*):
Pass it to send all of these arguments directly to `generate` allowing full control of this function.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
@ -105,8 +110,8 @@ class ImageToTextPipeline(Pipeline):
"""
return super().__call__(images, **kwargs)
def preprocess(self, image, prompt=None):
image = load_image(image)
def preprocess(self, image, prompt=None, timeout=None):
image = load_image(image, timeout=timeout)
if prompt is not None:
if not isinstance(prompt, str):

View File

@ -113,6 +113,8 @@ class MaskGenerationPipeline(ChunkPipeline):
preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
if "crop_n_points_downscale_factor" in kwargs:
preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
if "timeout" in kwargs:
preprocess_kwargs["timeout"] = kwargs["timeout"]
# postprocess args
if "pred_iou_thresh" in kwargs:
forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
@ -156,6 +158,9 @@ class MaskGenerationPipeline(ChunkPipeline):
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
`Dict`: A dictionary with the following keys:
@ -175,8 +180,9 @@ class MaskGenerationPipeline(ChunkPipeline):
crop_overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[int] = 1,
timeout: Optional[float] = None,
):
image = load_image(image)
image = load_image(image, timeout=timeout)
target_size = self.image_processor.size["longest_edge"]
crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor

View File

@ -61,10 +61,13 @@ class ObjectDetectionPipeline(Pipeline):
self.check_model_type(mapping)
def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
if "timeout" in kwargs:
preprocess_params["timeout"] = kwargs["timeout"]
postprocess_kwargs = {}
if "threshold" in kwargs:
postprocess_kwargs["threshold"] = kwargs["threshold"]
return {}, {}, postprocess_kwargs
return preprocess_params, {}, postprocess_kwargs
def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
"""
@ -82,6 +85,9 @@ class ObjectDetectionPipeline(Pipeline):
same format: all as HTTP(S) links, all as local paths, or all as PIL images.
threshold (`float`, *optional*, defaults to 0.9):
The probability necessary to make a prediction.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
@ -97,8 +103,8 @@ class ObjectDetectionPipeline(Pipeline):
return super().__call__(*args, **kwargs)
def preprocess(self, image):
image = load_image(image)
def preprocess(self, image, timeout=None):
image = load_image(image, timeout=timeout)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.image_processor(images=[image], return_tensors="pt")
if self.tokenizer is not None:

View File

@ -55,12 +55,14 @@ class VisualQuestionAnsweringPipeline(Pipeline):
super().__init__(*args, **kwargs)
self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES)
def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, timeout=None, **kwargs):
preprocess_params, postprocess_params = {}, {}
if padding is not None:
preprocess_params["padding"] = padding
if truncation is not None:
preprocess_params["truncation"] = truncation
if timeout is not None:
preprocess_params["timeout"] = timeout
if top_k is not None:
postprocess_params["top_k"] = top_k
return preprocess_params, {}, postprocess_params
@ -90,6 +92,9 @@ class VisualQuestionAnsweringPipeline(Pipeline):
top_k (`int`, *optional*, defaults to 5):
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A dictionary or a list of dictionaries containing the result. The dictionaries contain the following keys:
@ -109,8 +114,8 @@ class VisualQuestionAnsweringPipeline(Pipeline):
results = super().__call__(inputs, **kwargs)
return results
def preprocess(self, inputs, padding=False, truncation=False):
image = load_image(inputs["image"])
def preprocess(self, inputs, padding=False, truncation=False, timeout=None):
image = load_image(inputs["image"], timeout=timeout)
model_inputs = self.tokenizer(
inputs["question"], return_tensors=self.framework, padding=padding, truncation=truncation
)

View File

@ -91,6 +91,10 @@ class ZeroShotImageClassificationPipeline(Pipeline):
replacing the placeholder with the candidate_labels. Then likelihood is estimated by using
logits_per_image
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A list of dictionaries containing result, one dictionary per proposed label. The dictionaries contain the
following keys:
@ -104,13 +108,15 @@ class ZeroShotImageClassificationPipeline(Pipeline):
preprocess_params = {}
if "candidate_labels" in kwargs:
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
if "timeout" in kwargs:
preprocess_params["timeout"] = kwargs["timeout"]
if "hypothesis_template" in kwargs:
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
return preprocess_params, {}, {}
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."):
image = load_image(image)
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}.", timeout=None):
image = load_image(image, timeout=timeout)
inputs = self.image_processor(images=[image], return_tensors=self.framework)
inputs["candidate_labels"] = candidate_labels
sequences = [hypothesis_template.format(x) for x in candidate_labels]

View File

@ -111,6 +111,10 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
The number of top predictions that will be returned by the pipeline. If the provided number is `None`
or higher than the number of predictions available, it will default to the number of predictions.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
Return:
A list of lists containing prediction results, one list per input image. Each list contains dictionaries
@ -132,15 +136,18 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
return results
def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
if "timeout" in kwargs:
preprocess_params["timeout"] = kwargs["timeout"]
postprocess_params = {}
if "threshold" in kwargs:
postprocess_params["threshold"] = kwargs["threshold"]
if "top_k" in kwargs:
postprocess_params["top_k"] = kwargs["top_k"]
return {}, {}, postprocess_params
return preprocess_params, {}, postprocess_params
def preprocess(self, inputs):
image = load_image(inputs["image"])
def preprocess(self, inputs, timeout=None):
image = load_image(inputs["image"], timeout=timeout)
candidate_labels = inputs["candidate_labels"]
if isinstance(candidate_labels, str):
candidate_labels = candidate_labels.split(",")

View File

@ -18,7 +18,9 @@ import unittest
import datasets
import numpy as np
import pytest
from requests import ReadTimeout
from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
from transformers import is_torch_available, is_vision_available
from transformers.image_utils import ChannelDimension, get_channel_dimension_axis, make_list_of_images
from transformers.testing_utils import require_torch, require_vision
@ -478,6 +480,16 @@ class ImageFeatureExtractionTester(unittest.TestCase):
@require_vision
class LoadImageTester(unittest.TestCase):
def test_load_img_url(self):
img = load_image(INVOICE_URL)
img_arr = np.array(img)
self.assertEqual(img_arr.shape, (1061, 750, 3))
def test_load_img_url_timeout(self):
with self.assertRaises(ReadTimeout):
load_image(INVOICE_URL, timeout=0.001)
def test_load_img_local(self):
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
img_arr = np.array(img)