Update OnnxConfig.generate_dummy_inputs to check ImageProcessingMixin (#20157)

* Check ImageProcessingMixin in OnnxConfig.generate_dummy_inputs

* Check ImageProcessingMixin in OnnxConfig.generate_dummy_inputs

* Add back

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-11-10 16:04:51 +01:00 committed by GitHub
parent daf4436e07
commit e0d7c831c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,6 +28,7 @@ from .utils import ParameterFormat, compute_effective_axis_dimension, compute_se
if TYPE_CHECKING:
from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import FeatureExtractionMixin
from ..image_processing_utils import ImageProcessingMixin
from ..tokenization_utils_base import PreTrainedTokenizerBase
@ -278,7 +279,7 @@ class OnnxConfig(ABC):
def generate_dummy_inputs(
self,
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin", "ImageProcessingMixin"],
batch_size: int = -1,
seq_length: int = -1,
num_choices: int = -1,
@ -296,7 +297,7 @@ class OnnxConfig(ABC):
Generate inputs to provide to the ONNX exporter for the specific framework
Args:
preprocessor: ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]):
preprocessor: ([`PreTrainedTokenizerBase`], [`FeatureExtractionMixin`], or [`ImageProcessingMixin`]):
The preprocessor associated with this model configuration.
batch_size (`int`, *optional*, defaults to -1):
The batch size to export the model for (-1 means dynamic axis).
@ -325,6 +326,7 @@ class OnnxConfig(ABC):
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
"""
from ..feature_extraction_utils import FeatureExtractionMixin
from ..image_processing_utils import ImageProcessingMixin
from ..tokenization_utils_base import PreTrainedTokenizerBase
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
@ -368,6 +370,16 @@ class OnnxConfig(ABC):
tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
return dict(preprocessor(dummy_input, return_tensors=framework))
elif isinstance(preprocessor, ImageProcessingMixin):
if preprocessor.model_input_names[0] != "pixel_values":
raise ValueError(
f"The `preprocessor` is an image processor ({preprocessor.__class__.__name__}) and expects"
f' `model_input_names[0]` to be "pixel_values", but got {preprocessor.model_input_names[0]}'
)
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
return dict(preprocessor(images=dummy_input, return_tensors=framework))
elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)