mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Support loading base64 images in pipelines (#25633)
* support loading base64 images * add test * mention in docs * remove the logging * sort imports * update error message * Update tests/utils/test_image_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * restructure to catch base64 exception * doesn't like the newline * download files * format * optimize imports * guess it needs a space? * support loading base64 images * add test * remove the logging * sort imports * restructure to catch base64 exception * doesn't like the newline * download files * optimize imports * guess it needs a space? --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
ce2d4bc6a1
commit
dbc16f4404
@ -204,7 +204,7 @@ page.
|
||||
|
||||
Using a [`pipeline`] for vision tasks is practically identical.
|
||||
|
||||
Specify your task and pass your image to the classifier. The image can be a link or a local path to the image. For example, what species of cat is shown below?
|
||||
Specify your task and pass your image to the classifier. The image can be a link, a local path or a base64-encoded image. For example, what species of cat is shown below?
|
||||
|
||||

|
||||
|
||||
|
@ -13,7 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@ -298,14 +300,22 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
|
||||
elif os.path.isfile(image):
|
||||
image = PIL.Image.open(image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
||||
)
|
||||
if image.startswith("data:image/"):
|
||||
image = image.split(",")[1]
|
||||
|
||||
# Try to load as base64
|
||||
try:
|
||||
b64 = base64.b64decode(image, validate=True)
|
||||
image = PIL.Image.open(BytesIO(b64))
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
||||
)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = image
|
||||
else:
|
||||
raise ValueError(
|
||||
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
|
||||
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
|
||||
)
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
|
@ -13,11 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.file_download import http_get
|
||||
from requests import ReadTimeout
|
||||
|
||||
from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
|
||||
@ -500,6 +503,40 @@ class LoadImageTester(unittest.TestCase):
|
||||
(480, 640, 3),
|
||||
)
|
||||
|
||||
def test_load_img_base64_prefix(self):
    """Download a base64 text fixture and check that `load_image` decodes it.

    The fixture (`image_0.txt`) is fetched over the network, read back as UTF-8
    text and handed to `load_image` as a plain string. The decoded image must
    convert to an array of shape (64, 32, 3).

    NOTE(review): judging by the test name, this fixture presumably carries a
    `data:image/...` prefix before the base64 payload -- confirm against the
    dataset contents.
    """
    # A TemporaryDirectory replaces the deprecated, race-prone `tempfile.mktemp()`
    # and guarantees cleanup even if the download or decoding fails, without the
    # risk of `finally` touching a name that was never bound.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = os.path.join(tmp_dir, "image_0.txt")
        with open(tmp_file, "wb") as f:
            http_get(
                "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_0.txt", f
            )

        with open(tmp_file, encoding="utf-8") as b64:
            img = load_image(b64.read())
            # Materialize the pixels before the temporary directory is removed.
            img_arr = np.array(img)

    self.assertEqual(img_arr.shape, (64, 32, 3))
|
||||
|
||||
def test_load_img_base64(self):
    """Download a base64 text fixture and check that `load_image` decodes it.

    The fixture (`image_1.txt`) is fetched over the network, read back as UTF-8
    text and handed to `load_image` as a plain string. The decoded image must
    convert to an array of shape (64, 32, 3).

    NOTE(review): this fixture is presumably a bare base64 payload without a
    `data:image/...` prefix -- confirm against the dataset contents.
    """
    # A TemporaryDirectory replaces the deprecated, race-prone `tempfile.mktemp()`
    # and guarantees cleanup even if the download or decoding fails, without the
    # risk of `finally` touching a name that was never bound.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = os.path.join(tmp_dir, "image_1.txt")
        with open(tmp_file, "wb") as f:
            http_get(
                "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_1.txt", f
            )

        with open(tmp_file, encoding="utf-8") as b64:
            img = load_image(b64.read())
            # Materialize the pixels before the temporary directory is removed.
            img_arr = np.array(img)

    self.assertEqual(img_arr.shape, (64, 32, 3))
|
||||
|
||||
def test_load_img_rgba(self):
|
||||
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user