Enable ONNX export when PyTorch and TensorFlow installed in the same environment (#15625)

This commit is contained in:
lewtun 2022-02-11 16:25:06 +01:00 committed by GitHub
parent 6cf06d198c
commit 7e4844fc2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -303,8 +303,16 @@ class FeaturesManager:
The instance of the model.
"""
# If PyTorch and TensorFlow are installed in the same environment, we
# load an AutoModel class by default
model_class = FeaturesManager.get_model_class_for_feature(feature)
return model_class.from_pretrained(model)
try:
model = model_class.from_pretrained(model)
# Load TensorFlow weights in an AutoModel instance if PyTorch and
# TensorFlow are installed in the same environment
except OSError:
model = model_class.from_pretrained(model, from_tf=True)
return model
@staticmethod
def check_supported_model_or_raise(