diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index 1c8d10939a1..bbf06b07929 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -28,6 +28,7 @@ from .utils import ParameterFormat, compute_effective_axis_dimension, compute_se if TYPE_CHECKING: from ..configuration_utils import PretrainedConfig from ..feature_extraction_utils import FeatureExtractionMixin + from ..image_processing_utils import ImageProcessingMixin from ..tokenization_utils_base import PreTrainedTokenizerBase @@ -278,7 +279,7 @@ class OnnxConfig(ABC): def generate_dummy_inputs( self, - preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"], + preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin", "ImageProcessingMixin"], batch_size: int = -1, seq_length: int = -1, num_choices: int = -1, @@ -296,7 +297,7 @@ class OnnxConfig(ABC): Generate inputs to provide to the ONNX exporter for the specific framework Args: - preprocessor: ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]): + preprocessor: ([`PreTrainedTokenizerBase`], [`FeatureExtractionMixin`], or [`ImageProcessingMixin`]): The preprocessor associated with this model configuration. batch_size (`int`, *optional*, defaults to -1): The batch size to export the model for (-1 means dynamic axis). @@ -325,6 +326,7 @@ class OnnxConfig(ABC): Mapping[str, Tensor] holding the kwargs to provide to the model's forward function """ from ..feature_extraction_utils import FeatureExtractionMixin + from ..image_processing_utils import ImageProcessingMixin from ..tokenization_utils_base import PreTrainedTokenizerBase if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None: @@ -368,6 +370,16 @@ class OnnxConfig(ABC): tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)] return dict(tokenized_input.convert_to_tensors(tensor_type=framework)) return dict(preprocessor(dummy_input, return_tensors=framework)) + elif isinstance(preprocessor, ImageProcessingMixin): + if preprocessor.model_input_names[0] != "pixel_values": + raise ValueError( + f"The `preprocessor` is an image processor ({preprocessor.__class__.__name__}) and expects" + f' `model_input_names[0]` to be "pixel_values", but got {preprocessor.model_input_names[0]}' + ) + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) + dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) + return dict(preprocessor(images=dummy_input, return_tensors=framework)) elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values": # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)