Task-specific pipeline init args (#28439)

* Abstract out pipeline init args

* Address PR comments

* Reword

* BC PIPELINE_INIT_ARGS

* Remove old arguments

* Small fix
This commit is contained in:
amyeroberts 2024-01-30 16:54:57 +00:00 committed by GitHub
parent 2fa1c808ae
commit 1d489b3e61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 112 additions and 112 deletions

View File

@ -18,7 +18,7 @@ import numpy as np
import requests
from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_torch_available():
@ -63,7 +63,7 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
return audio
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_feature_extractor=True))
class AudioClassificationPipeline(Pipeline):
"""
Audio classification pipeline using any `AutoModelForAudioClassification`. This pipeline predicts the class of a

View File

@ -702,14 +702,33 @@ class _ScikitCompat(ABC):
raise NotImplementedError()
PIPELINE_INIT_ARGS = r"""
def build_pipeline_init_args(
has_tokenizer: bool = False,
has_feature_extractor: bool = False,
has_image_processor: bool = False,
supports_binary_output: bool = True,
) -> str:
docstring = r"""
Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow."""
if has_tokenizer:
docstring += r"""
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
[`PreTrainedTokenizer`]."""
if has_feature_extractor:
docstring += r"""
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode data for the model. This object inherits from
[`SequenceFeatureExtractor`]."""
if has_image_processor:
docstring += r"""
image_processor ([`BaseImageProcessor`]):
The image processor that will be used by the pipeline to encode data for the model. This object inherits from
[`BaseImageProcessor`]."""
docstring += r"""
modelcard (`str` or [`ModelCard`], *optional*):
Model card attributed to the model for this pipeline.
framework (`str`, *optional*):
@ -732,10 +751,22 @@ PIPELINE_INIT_ARGS = r"""
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id. You can pass native `torch.device` or a `str` too.
the associated CUDA device id. You can pass native `torch.device` or a `str` too
torch_dtype (`str` or `torch.dtype`, *optional*):
Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
(`torch.float16`, `torch.bfloat16`, ... or `"auto"`)"""
if supports_binary_output:
docstring += r"""
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
Flag indicating if the output the pipeline should happen in a serialized format (i.e., pickle) or as
the raw output data e.g. text."""
return docstring
PIPELINE_INIT_ARGS = build_pipeline_init_args(
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True
)
if is_torch_available():
from transformers.pipelines.pt_utils import (
@ -746,7 +777,7 @@ if is_torch_available():
)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_feature_extractor=True, has_image_processor=True))
class Pipeline(_ScikitCompat):
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across

View File

@ -2,7 +2,7 @@ import uuid
from typing import Any, Dict, List, Union
from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_tf_available():
@ -192,13 +192,12 @@ class Conversation:
@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_tokenizer=True),
r"""
min_length_for_response (`int`, *optional*, defaults to 32):
The minimum length (in number of tokens) for a response.
minimum_tokens (`int`, *optional*, defaults to 10):
The minimum length of tokens to leave for a response.
""",
The minimum length of tokens to leave for a response.""",
)
class ConversationalPipeline(Pipeline):
"""

View File

@ -3,7 +3,7 @@ from typing import List, Union
import numpy as np
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -19,7 +19,7 @@ if is_torch_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class DepthEstimationPipeline(Pipeline):
"""
Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image.

View File

@ -25,7 +25,7 @@ from ..utils import (
is_vision_available,
logging,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
from .base import ChunkPipeline, build_pipeline_init_args
from .question_answering import select_starts_ends
@ -98,7 +98,7 @@ class ModelType(ExplicitEnum):
VisionEncoderDecoder = "vision_encoder_decoder"
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True, has_tokenizer=True))
class DocumentQuestionAnsweringPipeline(ChunkPipeline):
# TODO: Update task_summary docs to include an example with document QA and then update the first sentence
"""

View File

@ -1,9 +1,17 @@
from typing import Dict
from .base import GenericTensor, Pipeline
from ..utils import add_end_docstrings
from .base import GenericTensor, Pipeline, build_pipeline_init_args
# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
@add_end_docstrings(
build_pipeline_init_args(has_tokenizer=True, supports_binary_output=False),
r"""
tokenize_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the tokenizer.
return_tensors (`bool`, *optional*):
If `True`, returns a tensor according to the specified framework, otherwise returns a list.""",
)
class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
@ -27,34 +35,6 @@ class FeatureExtractionPipeline(Pipeline):
All models may be used for this pipeline. See a list of all models, including community-contributed models on
[huggingface.co/models](https://huggingface.co/models).
Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
modelcard (`str` or [`ModelCard`], *optional*):
Model card attributed to the model for this pipeline.
framework (`str`, *optional*):
The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified and
both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
provided.
return_tensors (`bool`, *optional*):
If `True`, returns a tensor according to the specified framework, otherwise returns a list.
task (`str`, defaults to `""`):
A task-identifier for the pipeline.
args_parser ([`~pipelines.ArgumentHandler`], *optional*):
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id.
tokenize_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the tokenizer.
"""
def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):

View File

@ -3,7 +3,7 @@ from typing import Dict
import numpy as np
from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline, PipelineException
from .base import GenericTensor, Pipeline, PipelineException, build_pipeline_init_args
if is_tf_available():
@ -20,7 +20,7 @@ logger = logging.get_logger(__name__)
@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_tokenizer=True),
r"""
top_k (`int`, defaults to 5):
The number of predictions to return.
@ -28,8 +28,8 @@ logger = logging.get_logger(__name__)
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
vocab. If the provided targets are not in the model vocab, they will be tokenized and the first resulting
token will be used (with a warning, and that might be slower).
""",
tokenizer_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the tokenizer.""",
)
class FillMaskPipeline(Pipeline):
"""

View File

@ -11,7 +11,7 @@ from ..utils import (
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -48,7 +48,7 @@ class ClassificationFunction(ExplicitEnum):
@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_image_processor=True),
r"""
function_to_apply (`str`, *optional*, defaults to `"default"`):
The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
@ -57,8 +57,7 @@ class ClassificationFunction(ExplicitEnum):
has several labels, will apply the softmax function on the output.
- `"sigmoid"`: Applies the sigmoid function on the output.
- `"softmax"`: Applies the softmax function on the output.
- `"none"`: Does not apply any function on the output.
""",
- `"none"`: Does not apply any function on the output.""",
)
class ImageClassificationPipeline(Pipeline):
"""

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Union
import numpy as np
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -27,7 +27,7 @@ Prediction = Dict[str, Any]
Predictions = List[Prediction]
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageSegmentationPipeline(Pipeline):
"""
Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and

View File

@ -22,7 +22,7 @@ from ..utils import (
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -36,7 +36,7 @@ if is_torch_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageToImagePipeline(Pipeline):
"""
Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous

View File

@ -8,7 +8,7 @@ from ..utils import (
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -27,7 +27,7 @@ if is_torch_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_image_processor=True))
class ImageToTextPipeline(Pipeline):
"""
Image To Text pipeline using a `AutoModelForVision2Seq`. This pipeline predicts a caption for a given image.

View File

@ -8,7 +8,7 @@ from ..utils import (
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
from .base import ChunkPipeline, build_pipeline_init_args
if is_torch_available():
@ -19,7 +19,17 @@ if is_torch_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(
build_pipeline_init_args(has_image_processor=True),
r"""
points_per_batch (*optional*, int, default to 64):
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
memory.
output_bboxes_mask (`bool`, *optional*, default to `False`):
Whether or not to output the bounding box predictions.
output_rle_masks (`bool`, *optional*, default to `False`):
Whether or not to output the masks in `RLE` format""",
)
class MaskGenerationPipeline(ChunkPipeline):
"""
Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
@ -48,23 +58,6 @@ class MaskGenerationPipeline(ChunkPipeline):
applies a variety of filters based on non maximum suppression to remove bad masks.
- image_processor.postprocess_masks_for_amg applies the NSM on the mask to only keep relevant ones.
Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode the input.
points_per_batch (*optional*, int, default to 64):
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
memory.
output_bboxes_mask (`bool`, *optional*, default to `False`):
Whether or not to output the bounding box predictions.
output_rle_masks (`bool`, *optional*, default to `False`):
Whether or not to output the masks in `RLE` format
Example:
```python

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Union
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -23,7 +23,7 @@ Prediction = Dict[str, Any]
Predictions = List[Prediction]
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ObjectDetectionPipeline(Pipeline):
"""
Object detection pipeline using any `AutoModelForObjectDetection`. This pipeline predicts bounding boxes of objects

View File

@ -17,7 +17,7 @@ from ..utils import (
is_torch_available,
logging,
)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline
from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args
logger = logging.get_logger(__name__)
@ -221,7 +221,7 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
return inputs
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class QuestionAnsweringPipeline(ChunkPipeline):
"""
Question Answering pipeline using any `ModelForQuestionAnswering`. See the [question answering

View File

@ -10,7 +10,7 @@ from ..utils import (
is_torch_available,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Dataset, Pipeline, PipelineException
from .base import ArgumentHandler, Dataset, Pipeline, PipelineException, build_pipeline_init_args
if is_torch_available():
@ -84,7 +84,7 @@ class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
return tqa_pipeline_inputs
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TableQuestionAnsweringPipeline(Pipeline):
"""
Table Question Answering pipeline using a `ModelForTableQuestionAnswering`. This pipeline is only available in

View File

@ -3,7 +3,7 @@ import warnings
from ..tokenization_utils import TruncationStrategy
from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_tf_available():
@ -22,7 +22,7 @@ class ReturnType(enum.Enum):
TEXT = 1
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class Text2TextGenerationPipeline(Pipeline):
"""
Pipeline for text to text generation using seq2seq models.
@ -213,7 +213,7 @@ class Text2TextGenerationPipeline(Pipeline):
return records
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class SummarizationPipeline(Text2TextGenerationPipeline):
"""
Summarize news articles and other documents.
@ -283,7 +283,7 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TranslationPipeline(Text2TextGenerationPipeline):
"""
Translates from one language to another.

View File

@ -5,7 +5,7 @@ from typing import Dict
import numpy as np
from ..utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline
from .base import GenericTensor, Pipeline, build_pipeline_init_args
if is_tf_available():
@ -32,7 +32,7 @@ class ClassificationFunction(ExplicitEnum):
@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_tokenizer=True),
r"""
return_all_scores (`bool`, *optional*, defaults to `False`):
Whether to return all prediction scores or just the one of the predicted class.
@ -43,8 +43,7 @@ class ClassificationFunction(ExplicitEnum):
has several labels, will apply the softmax function on the output.
- `"sigmoid"`: Applies the sigmoid function on the output.
- `"softmax"`: Applies the softmax function on the output.
- `"none"`: Does not apply any function on the output.
""",
- `"none"`: Does not apply any function on the output.""",
)
class TextClassificationPipeline(Pipeline):
"""

View File

@ -2,7 +2,7 @@ import enum
import warnings
from ..utils import add_end_docstrings, is_tf_available, is_torch_available
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_torch_available():
@ -20,7 +20,7 @@ class ReturnType(enum.Enum):
FULL_TEXT = 2
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TextGenerationPipeline(Pipeline):
"""
Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a

View File

@ -11,7 +11,7 @@ from ..utils import (
is_tf_available,
is_torch_available,
)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline, Dataset
from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args
if is_tf_available():
@ -59,7 +59,7 @@ class AggregationStrategy(ExplicitEnum):
@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_tokenizer=True),
r"""
ignore_labels (`List[str]`, defaults to `["O"]`):
A list of labels to ignore.
@ -90,8 +90,7 @@ class AggregationStrategy(ExplicitEnum):
cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
label is applied.
- "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
end up with different tags. Word entity will simply be the token with the maximum score.
""",
end up with different tags. Word entity will simply be the token with the maximum score.""",
)
class TokenClassificationPipeline(ChunkPipeline):
"""

View File

@ -4,7 +4,7 @@ from typing import List, Union
import requests
from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_decord_available():
@ -18,7 +18,7 @@ if is_torch_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class VideoClassificationPipeline(Pipeline):
"""
Video classification pipeline using any `AutoModelForVideoClassification`. This pipeline predicts the class of a

View File

@ -1,7 +1,7 @@
from typing import Union
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -15,7 +15,7 @@ if is_torch_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_image_processor=True))
class VisualQuestionAnsweringPipeline(Pipeline):
"""
Visual Question Answering pipeline using a `AutoModelForVisualQuestionAnswering`. This pipeline is currently only

View File

@ -23,13 +23,13 @@ from ..utils import (
logging,
)
from .audio_classification import ffmpeg_read
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_feature_extractor=True, has_tokenizer=True))
class ZeroShotAudioClassificationPipeline(Pipeline):
"""
Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you

View File

@ -5,7 +5,7 @@ import numpy as np
from ..tokenization_utils import TruncationStrategy
from ..utils import add_end_docstrings, logging
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline
from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args
logger = logging.get_logger(__name__)
@ -43,7 +43,7 @@ class ZeroShotClassificationArgumentHandler(ArgumentHandler):
return sequence_pairs, sequences
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class ZeroShotClassificationPipeline(ChunkPipeline):
"""
NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural

View File

@ -9,7 +9,7 @@ from ..utils import (
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
@ -29,7 +29,7 @@ if is_tf_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ZeroShotImageClassificationPipeline(Pipeline):
"""
Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Union
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
from .base import ChunkPipeline, build_pipeline_init_args
if is_vision_available():
@ -19,7 +19,7 @@ if is_torch_available():
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ZeroShotObjectDetectionPipeline(ChunkPipeline):
"""
Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of