mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add image to image pipeline (#25393)
* Add image to image pipeline Add image to image pipeline * remove swin2sr from tf auto * make ImageToImage importable * make style make style make style make style * remove tf support * remove nonused imports * fix postprocessing * add important comments; add unit tests * add documentation * remove support for TF * make fixup * fix typehint Image.Image * fix documentation code * address review request; fix unittest type checking * address review request; fix unittest type checking * make fixup * address reviews * Update src/transformers/pipelines/image_to_image.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * enhance docs * make style * make style * improve docetest time * improve docetest time * Update tests/pipelines/test_pipelines_image_to_image.py Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * Update tests/pipelines/test_pipelines_image_to_image.py Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * make fixup * undo faulty merge * undo faulty merge * add image-to-image to test pipeline mixin * Update src/transformers/pipelines/image_to_image.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/pipelines/test_pipelines_image_to_image.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * improve docs --------- Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
914771cbfe
commit
576cd45a57
@ -352,6 +352,12 @@ Pipelines available for computer vision tasks include the following.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageToImagePipeline
|
||||
|
||||
[[autodoc]] ImageToImagePipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ObjectDetectionPipeline
|
||||
|
||||
[[autodoc]] ObjectDetectionPipeline
|
||||
|
@ -266,6 +266,10 @@ The following auto classes are available for the following computer vision tasks
|
||||
|
||||
[[autodoc]] AutoModelForImageSegmentation
|
||||
|
||||
### AutoModelForImageToImage
|
||||
|
||||
[[autodoc]] AutoModelForImageToImage
|
||||
|
||||
### AutoModelForSemanticSegmentation
|
||||
|
||||
[[autodoc]] AutoModelForSemanticSegmentation
|
||||
|
@ -653,6 +653,7 @@ _import_structure = {
|
||||
"FillMaskPipeline",
|
||||
"ImageClassificationPipeline",
|
||||
"ImageSegmentationPipeline",
|
||||
"ImageToImagePipeline",
|
||||
"ImageToTextPipeline",
|
||||
"JsonPipelineDataFormat",
|
||||
"NerPipeline",
|
||||
@ -1120,6 +1121,7 @@ else:
|
||||
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
|
||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
@ -1157,6 +1159,7 @@ else:
|
||||
"AutoModelForDocumentQuestionAnswering",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForImageSegmentation",
|
||||
"AutoModelForImageToImage",
|
||||
"AutoModelForInstanceSegmentation",
|
||||
"AutoModelForMaskedImageModeling",
|
||||
"AutoModelForMaskedLM",
|
||||
@ -4740,6 +4743,7 @@ if TYPE_CHECKING:
|
||||
FillMaskPipeline,
|
||||
ImageClassificationPipeline,
|
||||
ImageSegmentationPipeline,
|
||||
ImageToImagePipeline,
|
||||
ImageToTextPipeline,
|
||||
JsonPipelineDataFormat,
|
||||
NerPipeline,
|
||||
@ -5157,6 +5161,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
@ -5194,6 +5199,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForDocumentQuestionAnswering,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForImageToImage,
|
||||
AutoModelForInstanceSegmentation,
|
||||
AutoModelForMaskedImageModeling,
|
||||
AutoModelForMaskedLM,
|
||||
|
@ -50,6 +50,7 @@ else:
|
||||
"MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
|
||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
@ -86,6 +87,7 @@ else:
|
||||
"AutoModelForDepthEstimation",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForImageSegmentation",
|
||||
"AutoModelForImageToImage",
|
||||
"AutoModelForInstanceSegmentation",
|
||||
"AutoModelForMaskGeneration",
|
||||
"AutoModelForTextEncoding",
|
||||
@ -230,6 +232,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
@ -267,6 +270,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForDocumentQuestionAnswering,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForImageToImage,
|
||||
AutoModelForInstanceSegmentation,
|
||||
AutoModelForMaskedImageModeling,
|
||||
AutoModelForMaskedLM,
|
||||
|
@ -1112,6 +1112,12 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("swin2sr", "Swin2SRForImageSuperResolution"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
@ -1197,6 +1203,8 @@ MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL
|
||||
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
|
||||
|
||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoModelForMaskGeneration(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
@ -1206,6 +1214,10 @@ class AutoModelForTextEncoding(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
||||
|
||||
|
||||
class AutoModelForImageToImage(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
|
||||
|
||||
|
||||
class AutoModel(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_MAPPING
|
||||
|
||||
|
@ -29,7 +29,7 @@ from ..image_processing_utils import BaseImageProcessor
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
|
||||
from ..models.auto.modeling_auto import AutoModelForDepthEstimation
|
||||
from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage
|
||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import (
|
||||
@ -64,6 +64,7 @@ from .feature_extraction import FeatureExtractionPipeline
|
||||
from .fill_mask import FillMaskPipeline
|
||||
from .image_classification import ImageClassificationPipeline
|
||||
from .image_segmentation import ImageSegmentationPipeline
|
||||
from .image_to_image import ImageToImagePipeline
|
||||
from .image_to_text import ImageToTextPipeline
|
||||
from .mask_generation import MaskGenerationPipeline
|
||||
from .object_detection import ObjectDetectionPipeline
|
||||
@ -394,6 +395,13 @@ SUPPORTED_TASKS = {
|
||||
"default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"image-to-image": {
|
||||
"impl": ImageToImagePipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForImageToImage,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "4aaedcb")}},
|
||||
"type": "image",
|
||||
},
|
||||
}
|
||||
|
||||
NO_FEATURE_EXTRACTOR_TASKS = set()
|
||||
@ -472,6 +480,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
|
||||
- `"image-classification"`
|
||||
- `"image-segmentation"`
|
||||
- `"image-to-text"`
|
||||
- `"image-to-image"`
|
||||
- `"object-detection"`
|
||||
- `"question-answering"`
|
||||
- `"summarization"`
|
||||
@ -556,6 +565,7 @@ def pipeline(
|
||||
- `"fill-mask"`: will return a [`FillMaskPipeline`]:.
|
||||
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
|
||||
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
|
||||
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
|
||||
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
|
||||
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].
|
||||
- `"object-detection"`: will return a [`ObjectDetectionPipeline`].
|
||||
|
134
src/transformers/pipelines/image_to_image.py
Normal file
134
src/transformers/pipelines/image_to_image.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class ImageToImagePipeline(Pipeline):
|
||||
"""
|
||||
Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous
|
||||
image input.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> upscaler = pipeline("image-to-image", model="caidas/swin2SR-classical-sr-x2-64")
|
||||
>>> img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
>>> img = img.resize((64, 64))
|
||||
>>> upscaled_img = upscaler(img)
|
||||
>>> img.size
|
||||
(64, 64)
|
||||
|
||||
>>> upscaled_img.size
|
||||
(144, 144)
|
||||
```
|
||||
|
||||
This image to image pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
`"image-to-image"`.
|
||||
|
||||
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=image-to-image).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_params = {}
|
||||
postprocess_params = {}
|
||||
forward_params = {}
|
||||
|
||||
if "timeout" in kwargs:
|
||||
preprocess_params["timeout"] = kwargs["timeout"]
|
||||
if "head_mask" in kwargs:
|
||||
forward_params["head_mask"] = kwargs["head_mask"]
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def __call__(
|
||||
self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs
|
||||
) -> Union["Image.Image", List["Image.Image"]]:
|
||||
"""
|
||||
Transform the image(s) passed as inputs.
|
||||
|
||||
Args:
|
||||
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||
The pipeline handles three types of images:
|
||||
|
||||
- A string containing a http link pointing to an image
|
||||
- A string containing a local path to an image
|
||||
- An image loaded in PIL directly
|
||||
|
||||
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
|
||||
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
|
||||
images.
|
||||
timeout (`float`, *optional*, defaults to None):
|
||||
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
|
||||
the call may block forever.
|
||||
|
||||
Return:
|
||||
An image (Image.Image) or a list of images (List["Image.Image"]) containing result(s). If the input is a
|
||||
single image, the return will be also a single image, if the input is a list of several images, it will
|
||||
return a list of transformed images.
|
||||
"""
|
||||
return super().__call__(images, **kwargs)
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
model_outputs = self.model(**model_inputs)
|
||||
return model_outputs
|
||||
|
||||
def preprocess(self, image, timeout=None):
|
||||
image = load_image(image, timeout=timeout)
|
||||
inputs = self.image_processor(images=[image], return_tensors="pt")
|
||||
return inputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
images = []
|
||||
if "reconstruction" in model_outputs.keys():
|
||||
outputs = model_outputs.reconstruction
|
||||
for output in outputs:
|
||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = np.moveaxis(output, source=0, destination=-1)
|
||||
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
||||
images.append(Image.fromarray(output))
|
||||
|
||||
return images if len(images) > 1 else images[0]
|
@ -573,6 +573,9 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
|
||||
|
||||
|
||||
@ -728,6 +731,13 @@ class AutoModelForImageSegmentation(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForImageToImage(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForInstanceSegmentation(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
85
tests/pipelines/test_pipelines_image_to_image.py
Normal file
85
tests/pipelines/test_pipelines_image_to_image.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
|
||||
AutoImageProcessor,
|
||||
AutoModelForImageToImage,
|
||||
ImageToImagePipeline,
|
||||
is_vision_available,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
else:
|
||||
|
||||
class Image:
|
||||
@staticmethod
|
||||
def open(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@require_torch
|
||||
@require_vision
|
||||
class ImageToImagePipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
|
||||
examples = [
|
||||
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
]
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
@slow
|
||||
def test_pipeline(self):
|
||||
model_id = "caidas/swin2SR-classical-sr-x2-64"
|
||||
upscaler = pipeline("image-to-image", model=model_id)
|
||||
upscaled_list = upscaler(self.examples)
|
||||
|
||||
self.assertEqual(len(upscaled_list), len(self.examples))
|
||||
for output in upscaled_list:
|
||||
self.assertIsInstance(output, Image.Image)
|
||||
|
||||
self.assertEqual(upscaled_list[0].size, (1296, 976))
|
||||
self.assertEqual(upscaled_list[1].size, (1296, 976))
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
@slow
|
||||
def test_pipeline_model_processor(self):
|
||||
model_id = "caidas/swin2SR-classical-sr-x2-64"
|
||||
model = AutoModelForImageToImage.from_pretrained(model_id)
|
||||
image_processor = AutoImageProcessor.from_pretrained(model_id)
|
||||
|
||||
upscaler = ImageToImagePipeline(model=model, image_processor=image_processor)
|
||||
upscaled_list = upscaler(self.examples)
|
||||
|
||||
self.assertEqual(len(upscaled_list), len(self.examples))
|
||||
for output in upscaled_list:
|
||||
self.assertIsInstance(output, Image.Image)
|
||||
|
||||
self.assertEqual(upscaled_list[0].size, (1296, 976))
|
||||
self.assertEqual(upscaled_list[1].size, (1296, 976))
|
@ -40,6 +40,7 @@ from .pipelines.test_pipelines_feature_extraction import FeatureExtractionPipeli
|
||||
from .pipelines.test_pipelines_fill_mask import FillMaskPipelineTests
|
||||
from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
|
||||
from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
|
||||
from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
|
||||
from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
|
||||
from .pipelines.test_pipelines_mask_generation import MaskGenerationPipelineTests
|
||||
from .pipelines.test_pipelines_object_detection import ObjectDetectionPipelineTests
|
||||
@ -70,6 +71,7 @@ pipeline_test_mapping = {
|
||||
"fill-mask": {"test": FillMaskPipelineTests},
|
||||
"image-classification": {"test": ImageClassificationPipelineTests},
|
||||
"image-segmentation": {"test": ImageSegmentationPipelineTests},
|
||||
"image-to-image": {"test": ImageToImagePipelineTests},
|
||||
"image-to-text": {"test": ImageToTextPipelineTests},
|
||||
"mask-generation": {"test": MaskGenerationPipelineTests},
|
||||
"object-detection": {"test": ObjectDetectionPipelineTests},
|
||||
|
@ -67,6 +67,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
|
||||
("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
|
||||
("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
|
||||
("image-to-image", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES", "AutoModelForImageToImage"),
|
||||
("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
|
||||
("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
|
||||
(
|
||||
|
Loading…
Reference in New Issue
Block a user