Support for version spec in requires & arbitrary mismatching depths across folders (#37854)

* Support for version spec in requires & arbitrary mismatching depths

* Quality

* Testing
Lysandre Debut 2025-05-09 15:26:27 +02:00 committed by GitHub
parent 774dc274ac
commit 23d79cea75
4 changed files with 391 additions and 45 deletions

View File

@ -84,6 +84,19 @@ class Trainer:
Backends that can be added here are all the backends available in the `import_utils.py` module.
Additionally, a specific version can be specified for each backend. For example, this is how you would specify
a requirement of torch>=2.6 on the `Trainer` class:
```python
from .utils.import_utils import requires

@requires(backends=("torch>=2.6", "accelerate"))
class Trainer:
    ...
```
You can specify the following operators: `==`, `>`, `>=`, `<`, `<=`, `!=`.
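For instance, several backends can each carry their own operator. The class name and version pins below are purely illustrative:
```python
from .utils.import_utils import requires

# Hypothetical example: require a recent torch while excluding a specific accelerate release.
@requires(backends=("torch>=2.6", "accelerate!=0.27.0"))
class MyModel:
    ...
```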
## Methods
[[autodoc]] utils.import_utils.define_import_structure

View File

@ -19,16 +19,19 @@ import importlib.machinery
import importlib.metadata
import importlib.util
import json
import operator
import os
import re
import shutil
import subprocess
import sys
import warnings
from collections import OrderedDict
from enum import Enum
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Union
from packaging import version
@ -1838,8 +1841,16 @@ def requires_backends(obj, backends):
if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
failed = []
for backend in backends:
if isinstance(backend, Backend):
available, msg = backend.is_satisfied, backend.error_message
else:
available, msg = BACKENDS_MAPPING[backend]
if not available():
failed.append(msg.format(name))
if failed:
raise ImportError("".join(failed))
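For illustration, a hedged sketch of how a caller might mix plain backend names with version-pinned requirements under this change (the class name and version pin below are hypothetical):
```python
from transformers.utils.import_utils import Backend, requires_backends


class MyFeature:
    def __init__(self):
        # Plain strings are looked up in BACKENDS_MAPPING; Backend instances carry their
        # own version check and error message.
        requires_backends(self, ["torch", Backend("accelerate>=0.26")])
```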
@ -1884,10 +1895,13 @@ class _LazyModule(ModuleType):
import_structure: IMPORT_STRUCTURE_T,
module_spec: Optional[importlib.machinery.ModuleSpec] = None,
extra_objects: Optional[Dict[str, object]] = None,
explicit_import_shortcut: Optional[Dict[str, List[str]]] = None,
):
super().__init__(name)
self._object_missing_backend = {}
self._explicit_import_shortcut = explicit_import_shortcut if explicit_import_shortcut else {}
if any(isinstance(key, frozenset) for key in import_structure.keys()):
self._modules = set()
self._class_to_module = {}
@ -1916,14 +1930,25 @@ class _LazyModule(ModuleType):
module_keys = set(
chain(*[[k.rsplit(".", i)[0] for i in range(k.count(".") + 1)] for k in list(module.keys())])
)
for backend in backends:
if backend not in BACKENDS_MAPPING:
raise ValueError(
f"Error: the following backend: '{backend}' was specified around object {module} but isn't specified in the backends mapping."
)
callable, error = BACKENDS_MAPPING[backend]
if not callable():
if backend in BACKENDS_MAPPING:
callable, _ = BACKENDS_MAPPING[backend]
else:
if any(key in backend for key in ["=", "<", ">"]):
backend = Backend(backend)
callable = backend.is_satisfied
else:
raise ValueError(
f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}"
)
try:
if not callable():
missing_backends.append(backend)
except (importlib.metadata.PackageNotFoundError, ModuleNotFoundError, RuntimeError):
missing_backends.append(backend)
self._modules = self._modules.union(module_keys)
for key, values in module.items():
@ -2000,12 +2025,29 @@ class _LazyModule(ModuleType):
value = Placeholder
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
try:
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
except (ModuleNotFoundError, RuntimeError) as e:
raise ModuleNotFoundError(
f"Could not import module '{name}'. Are this object's requirements defined correctly?"
) from e
elif name in self._modules:
value = self._get_module(name)
try:
value = self._get_module(name)
except (ModuleNotFoundError, RuntimeError) as e:
raise ModuleNotFoundError(
f"Could not import module '{name}'. Are this object's requirements defined correctly?"
) from e
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
value = None
for key, values in self._explicit_import_shortcut.items():
if name in values:
value = self._get_module(key)
if value is None:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
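The `explicit_import_shortcut` mapping is not documented in this diff; the sketch below only illustrates how `__getattr__` appears to consume it. The module key and attribute name are hypothetical:
```python
# Hypothetical mapping passed as _LazyModule(..., explicit_import_shortcut=...):
# module key -> attribute names that should fall back to that module.
explicit_import_shortcut = {"some_submodule": ["SomeHelper"]}

# With this mapping, accessing `lazy_module.SomeHelper` when the regular import structure
# has no entry for "SomeHelper" resolves to the module returned by
# self._get_module("some_submodule") instead of raising AttributeError.
```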
@ -2046,6 +2088,64 @@ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
return module
class VersionComparison(Enum):
    EQUAL = operator.eq
    NOT_EQUAL = operator.ne
    GREATER_THAN = operator.gt
    LESS_THAN = operator.lt
    GREATER_THAN_OR_EQUAL = operator.ge
    LESS_THAN_OR_EQUAL = operator.le

    @staticmethod
    def from_string(version_string: str) -> "VersionComparison":
        string_to_operator = {
            "=": VersionComparison.EQUAL.value,
            "==": VersionComparison.EQUAL.value,
            "!=": VersionComparison.NOT_EQUAL.value,
            ">": VersionComparison.GREATER_THAN.value,
            "<": VersionComparison.LESS_THAN.value,
            ">=": VersionComparison.GREATER_THAN_OR_EQUAL.value,
            "<=": VersionComparison.LESS_THAN_OR_EQUAL.value,
        }

        return string_to_operator[version_string]
@lru_cache()
def split_package_version(package_version_str) -> Tuple[str, str, str]:
    pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
    match = re.match(pattern, package_version_str)
    if match:
        return (match.group(1), match.group(2), match.group(3))
    else:
        raise ValueError(f"Invalid package version string: {package_version_str}")
class Backend:
    def __init__(self, backend_requirement: str):
        self.package_name, self.version_comparison, self.version = split_package_version(backend_requirement)

        if self.package_name not in BACKENDS_MAPPING:
            raise ValueError(
                f"Backends should be defined in the BACKENDS_MAPPING. Offending backend: {self.package_name}"
            )

    def is_satisfied(self) -> bool:
        return VersionComparison.from_string(self.version_comparison)(
            version.parse(importlib.metadata.version(self.package_name)), version.parse(self.version)
        )

    def __repr__(self) -> str:
        return f'Backend("{self.package_name}", "{self.version_comparison}", "{self.version}")'

    @property
    def error_message(self):
        return (
            f"{{0}} requires the {self.package_name} library version {self.version_comparison}{self.version}. That"
            f" library was not found with this version in your environment."
        )
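A short usage sketch of `Backend` itself, assuming torch is installed in the environment:
```python
from transformers.utils.import_utils import Backend

backend = Backend("torch>=2.6")
print(backend.package_name, backend.version_comparison, backend.version)  # torch >= 2.6
print(backend.is_satisfied())  # True if the installed torch version is >= 2.6
print(backend.error_message.format("Trainer"))  # message raised when the requirement is unmet
```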
def requires(*, backends=()):
"""
This decorator enables two things:
@ -2053,15 +2153,22 @@ def requires(*, backends=()):
to execute correctly without instantiating it
- The '@requires' string is used to dynamically import objects
"""
for backend in backends:
if backend not in BACKENDS_MAPPING:
raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}")
if not isinstance(backends, tuple):
raise ValueError("Backends should be a tuple.")
applied_backends = []
for backend in backends:
if backend in BACKENDS_MAPPING:
applied_backends.append(backend)
else:
if any(key in backend for key in ["=", "<", ">"]):
applied_backends.append(Backend(backend))
else:
raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}")
def inner_fn(fun):
fun.__backends = backends
fun.__backends = applied_backends
return fun
return inner_fn
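As a usage sketch of the decorator above (the function name is arbitrary): the decorator only records the requirements, and the actual check happens later, e.g. through `requires_backends` or the lazy import machinery:
```python
from transformers.utils.import_utils import requires


@requires(backends=("torch>=2.6", "accelerate"))
def dummy_fn():
    pass


# The decorator attaches the parsed requirements without importing or checking anything yet.
print(dummy_fn.__backends)  # e.g. [Backend("torch", ">=", "2.6"), 'accelerate']
```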
@ -2369,23 +2476,53 @@ def spread_import_structure(nested_import_structure):
"""
def propagate_frozenset(unordered_import_structure):
tuple_first_import_structure = {}
frozenset_first_import_structure = {}
for _key, _value in unordered_import_structure.items():
# If the value is not a dict but a string, no need for custom manipulation
if not isinstance(_value, dict):
tuple_first_import_structure[_key] = _value
frozenset_first_import_structure[_key] = _value
elif any(isinstance(v, frozenset) for v in _value.keys()):
# Here we want to switch around key and v
for k, v in _value.items():
if isinstance(k, frozenset):
if k not in tuple_first_import_structure:
tuple_first_import_structure[k] = {}
tuple_first_import_structure[k][_key] = v
# Here we want to switch around _key and k to propagate k upstream if it is a frozenset
if k not in frozenset_first_import_structure:
frozenset_first_import_structure[k] = {}
if _key not in frozenset_first_import_structure[k]:
frozenset_first_import_structure[k][_key] = {}
frozenset_first_import_structure[k][_key].update(v)
else:
# If k is not a frozenset, it means that the dictionary is not "level": some keys (top-level)
# are frozensets, whereas some are not -> frozenset keys sit at an unknown depth of the
# dictionary.
#
# We recursively propagate the frozenset for this specific dictionary so that the frozensets
# are at the top-level when we handle them.
propagated_frozenset = propagate_frozenset({k: v})
for r_k, r_v in propagated_frozenset.items():
if isinstance(_key, frozenset):
if r_k not in frozenset_first_import_structure:
frozenset_first_import_structure[r_k] = {}
if _key not in frozenset_first_import_structure[r_k]:
frozenset_first_import_structure[r_k][_key] = {}
# _key is a frozenset -> we switch around the r_k and _key
frozenset_first_import_structure[r_k][_key].update(r_v)
else:
if _key not in frozenset_first_import_structure:
frozenset_first_import_structure[_key] = {}
if r_k not in frozenset_first_import_structure[_key]:
frozenset_first_import_structure[_key][r_k] = {}
# _key is not a frozenset -> we keep the order of r_k and _key
frozenset_first_import_structure[_key][r_k].update(r_v)
else:
tuple_first_import_structure[_key] = propagate_frozenset(_value)
frozenset_first_import_structure[_key] = propagate_frozenset(_value)
return tuple_first_import_structure
return frozenset_first_import_structure
def flatten_dict(_dict, previous_key=None):
items = []
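To make the key-swapping above concrete, here is a hedged example of what `spread_import_structure` (which wraps this propagation plus flattening) produces for a small nested structure. The module and object names are hypothetical, but the shape mirrors the `test_import_spread` ground truth added below:
```python
from transformers.utils.import_utils import spread_import_structure

nested = {
    "albert": {
        frozenset(): {"configuration_albert": {"AlbertConfig"}},
        frozenset({"torch"}): {"modeling_albert": {"AlbertModel"}},
    }
}

print(spread_import_structure(nested))
# Expected shape: frozensets hoisted to the top level, module paths flattened with dots:
# {frozenset(): {"albert.configuration_albert": {"AlbertConfig"}},
#  frozenset({"torch"}): {"albert.modeling_albert": {"AlbertModel"}}}
```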

View File

@ -0,0 +1,92 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# fmt: off
from transformers.utils.import_utils import requires
@requires(backends=("torch>=2.5",))
class D0:
    def __init__(self):
        pass


@requires(backends=("torch>=2.5",))
def d0():
    pass


@requires(backends=("torch>2.5",))
class D1:
    def __init__(self):
        pass


@requires(backends=("torch>2.5",))
def d1():
    pass


@requires(backends=("torch<=2.5",))
class D2:
    def __init__(self):
        pass


@requires(backends=("torch<=2.5",))
def d2():
    pass


@requires(backends=("torch<2.5",))
class D3:
    def __init__(self):
        pass


@requires(backends=("torch<2.5",))
def d3():
    pass


@requires(backends=("torch==2.5",))
class D4:
    def __init__(self):
        pass


@requires(backends=("torch==2.5",))
def d4():
    pass


@requires(backends=("torch!=2.5",))
class D5:
    def __init__(self):
        pass


@requires(backends=("torch!=2.5",))
def d5():
    pass


@requires(backends=("torch>=2.5", "accelerate<0.20"))
class D6:
    def __init__(self):
        pass


@requires(backends=("torch>=2.5", "accelerate<0.20"))
def d6():
    pass

View File

@ -1,11 +1,19 @@
import os
import unittest
from pathlib import Path
from typing import Callable
from transformers.utils.import_utils import define_import_structure, spread_import_structure
import pytest
from transformers.utils.import_utils import (
Backend,
VersionComparison,
define_import_structure,
spread_import_structure,
)
import_structures = Path("import_structures")
import_structures = Path(__file__).parent / "import_structures"
def fetch__all__(file_content):
@ -36,26 +44,39 @@ class TestImportStructures(unittest.TestCase):
models_path = base_transformers_path / "src" / "transformers" / "models"
models_import_structure = spread_import_structure(define_import_structure(models_path))
# TODO: Lysandre
# See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
@unittest.skip(reason="failing")
def test_definition(self):
import_structure = define_import_structure(import_structures)
import_structure_definition = {
frozenset(()): {
"import_structure_raw_register": {"A0", "a0", "A4"},
valid_frozensets: dict[frozenset | frozenset[str], dict[str, set[str]]] = {
frozenset(): {
"import_structure_raw_register": {"A0", "A4", "a0"},
"import_structure_register_with_comments": {"B0", "b0"},
},
frozenset(("tf", "torch")): {
"import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"},
"import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"},
frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}},
frozenset({"torch"}): {
"import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"}
},
frozenset(("torch",)): {
"import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"},
frozenset({"tf", "torch"}): {
"import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"},
"import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"},
},
frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}},
frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}},
frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}},
frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}},
frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}},
frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}},
frozenset({"torch>=2.5", "accelerate<0.20"}): {
"import_structure_raw_register_with_versions": {"D6", "d6"}
},
}
self.assertDictEqual(import_structure, import_structure_definition)
self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys()))
for _frozenset in valid_frozensets.keys():
self.assertTrue(_frozenset in import_structure)
self.assertListEqual(list(import_structure[_frozenset].keys()), list(valid_frozensets[_frozenset].keys()))
for module, objects in valid_frozensets[_frozenset].items():
self.assertTrue(module in import_structure[_frozenset])
self.assertSetEqual(objects, import_structure[_frozenset][module])
def test_transformers_specific_model_import(self):
"""
@ -96,9 +117,92 @@ class TestImportStructures(unittest.TestCase):
)
self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)
# TODO: Lysandre
# See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
@unittest.skip(reason="failing")
def test_export_backend_should_be_defined(self):
with self.assertRaisesRegex(ValueError, "Backend should be defined in the BACKENDS_MAPPING"):
pass
def test_import_spread(self):
"""
This test is specifically designed to test that varying levels of depth across import structures are
respected.
In this instance, frozensets are at respective depths of 1, 2 and 3, for example:
- models.{frozensets}
- models.albert.{frozensets}
- models.deprecated.transfo_xl.{frozensets}
"""
initial_import_structure = {
frozenset(): {"dummy_non_model": {"DummyObject"}},
"models": {
frozenset(): {"dummy_config": {"DummyConfig"}},
"albert": {
frozenset(): {"configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"}},
frozenset({"torch"}): {
"modeling_albert": {
"AlbertForMaskedLM",
}
},
},
"llama": {
frozenset(): {"configuration_llama": {"LlamaConfig"}},
frozenset({"torch"}): {
"modeling_llama": {
"LlamaForCausalLM",
}
},
},
"deprecated": {
"transfo_xl": {
frozenset({"torch"}): {
"modeling_transfo_xl": {
"TransfoXLModel",
}
},
frozenset(): {
"configuration_transfo_xl": {"TransfoXLConfig"},
"tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
},
},
"deta": {
frozenset({"torch"}): {
"modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}
},
frozenset(): {"configuration_deta": {"DetaConfig"}},
frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}},
},
},
},
}
ground_truth_spread_import_structure = {
frozenset(): {
"dummy_non_model": {"DummyObject"},
"models.dummy_config": {"DummyConfig"},
"models.albert.configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"},
"models.llama.configuration_llama": {"LlamaConfig"},
"models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"},
"models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
"models.deprecated.deta.configuration_deta": {"DetaConfig"},
},
frozenset({"torch"}): {
"models.albert.modeling_albert": {"AlbertForMaskedLM"},
"models.llama.modeling_llama": {"LlamaForCausalLM"},
"models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"},
"models.deprecated.deta.modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"},
},
frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}},
}
newly_spread_import_structure = spread_import_structure(initial_import_structure)
self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure)
@pytest.mark.parametrize(
"backend,package_name,version_comparison,version",
[
pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"),
pytest.param(Backend("tf<=1"), "tf", VersionComparison.LESS_THAN_OR_EQUAL.value, "1"),
pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"),
],
)
def test_backend_specification(backend: Backend, package_name: str, version_comparison: Callable, version: str):
assert backend.package_name == package_name
assert VersionComparison.from_string(backend.version_comparison) == version_comparison
assert backend.version == version