diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 09672ef9afb..6e7e9f154e6 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -41,7 +41,7 @@ from .configuration_vilt import ViltConfig logger = logging.get_logger(__name__) -if torch.__version__ < (1, 10, 0): +if version.parse(torch.__version__) < version.parse("1.10.0"): logger.warning( f"You are using torch=={torch.__version__}, but torch>=1.10.0 is required to use " "ViltModel. Please upgrade torch."