diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index d21d1d3072f..58db3ed3f4d 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -303,8 +303,16 @@ class FeaturesManager: The instance of the model. """ + # If PyTorch and TensorFlow are installed in the same environment, we + # load an AutoModel class by default model_class = FeaturesManager.get_model_class_for_feature(feature) - return model_class.from_pretrained(model) + try: + model = model_class.from_pretrained(model) + # Load TensorFlow weights in an AutoModel instance if PyTorch and + # TensorFlow are installed in the same environment + except OSError: + model = model_class.from_pretrained(model, from_tf=True) + return model @staticmethod def check_supported_model_or_raise(