diff --git a/docs/source/en/internal/import_utils.md b/docs/source/en/internal/import_utils.md index 93daa2ced3a..749ece15da6 100644 --- a/docs/source/en/internal/import_utils.md +++ b/docs/source/en/internal/import_utils.md @@ -84,6 +84,19 @@ class Trainer: Backends that can be added here are all the backends that are available in the `import_utils.py` module. +Additionally, specific versions can be specified in each backend. For example, this is how you would specify +a requirement on torch>=2.6 on the `Trainer` class: + +```python +from .utils.import_utils import requires + +@requires(backends=("torch>=2.6", "accelerate")) +class Trainer: + ... +``` + +You can specify the following operators: `==`, `>`, `>=`, `<`, `<=`, `!=`. + ## Methods [[autodoc]] utils.import_utils.define_import_structure diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 6aee59d44f6..3c0079459b1 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -19,16 +19,19 @@ import importlib.machinery import importlib.metadata import importlib.util import json +import operator import os +import re import shutil import subprocess import sys import warnings from collections import OrderedDict +from enum import Enum from functools import lru_cache from itertools import chain from types import ModuleType -from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union +from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Union from packaging import version @@ -1838,8 +1841,16 @@ def requires_backends(obj, backends): if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name)) - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] + failed = [] + for backend in backends: + if isinstance(backend, Backend): + 
available, msg = backend.is_satisfied, backend.error_message + else: + available, msg = BACKENDS_MAPPING[backend] + + if not available(): + failed.append(msg.format(name)) + if failed: raise ImportError("".join(failed)) @@ -1884,10 +1895,13 @@ class _LazyModule(ModuleType): import_structure: IMPORT_STRUCTURE_T, module_spec: Optional[importlib.machinery.ModuleSpec] = None, extra_objects: Optional[Dict[str, object]] = None, + explicit_import_shortcut: Optional[Dict[str, List[str]]] = None, ): super().__init__(name) self._object_missing_backend = {} + self._explicit_import_shortcut = explicit_import_shortcut if explicit_import_shortcut else {} + if any(isinstance(key, frozenset) for key in import_structure.keys()): self._modules = set() self._class_to_module = {} @@ -1916,14 +1930,25 @@ class _LazyModule(ModuleType): module_keys = set( chain(*[[k.rsplit(".", i)[0] for i in range(k.count(".") + 1)] for k in list(module.keys())]) ) + for backend in backends: - if backend not in BACKENDS_MAPPING: - raise ValueError( - f"Error: the following backend: '{backend}' was specified around object {module} but isn't specified in the backends mapping." - ) - callable, error = BACKENDS_MAPPING[backend] - if not callable(): + if backend in BACKENDS_MAPPING: + callable, _ = BACKENDS_MAPPING[backend] + else: + if any(key in backend for key in ["=", "<", ">"]): + backend = Backend(backend) + callable = backend.is_satisfied + else: + raise ValueError( + f"Backend should be defined in the BACKENDS_MAPPING. 
Offending backend: {backend}" + ) + + try: + if not callable(): + missing_backends.append(backend) + except (importlib.metadata.PackageNotFoundError, ModuleNotFoundError, RuntimeError): missing_backends.append(backend) + self._modules = self._modules.union(module_keys) for key, values in module.items(): @@ -2000,12 +2025,29 @@ class _LazyModule(ModuleType): value = Placeholder elif name in self._class_to_module.keys(): - module = self._get_module(self._class_to_module[name]) - value = getattr(module, name) + try: + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + except (ModuleNotFoundError, RuntimeError) as e: + raise ModuleNotFoundError( + f"Could not import module '{name}'. Are this object's requirements defined correctly?" + ) from e + elif name in self._modules: - value = self._get_module(name) + try: + value = self._get_module(name) + except (ModuleNotFoundError, RuntimeError) as e: + raise ModuleNotFoundError( + f"Could not import module '{name}'. Are this object's requirements defined correctly?" 
+ ) from e else: - raise AttributeError(f"module {self.__name__} has no attribute {name}") + value = None + for key, values in self._explicit_import_shortcut.items(): + if name in values: + value = self._get_module(key) + + if value is None: + raise AttributeError(f"module {self.__name__} has no attribute {name}") setattr(self, name, value) return value @@ -2046,6 +2088,64 @@ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType: return module +class VersionComparison(Enum): + EQUAL = operator.eq + NOT_EQUAL = operator.ne + GREATER_THAN = operator.gt + LESS_THAN = operator.lt + GREATER_THAN_OR_EQUAL = operator.ge + LESS_THAN_OR_EQUAL = operator.le + + @staticmethod + def from_string(version_string: str) -> "VersionComparison": + string_to_operator = { + "=": VersionComparison.EQUAL.value, + "==": VersionComparison.EQUAL.value, + "!=": VersionComparison.NOT_EQUAL.value, + ">": VersionComparison.GREATER_THAN.value, + "<": VersionComparison.LESS_THAN.value, + ">=": VersionComparison.GREATER_THAN_OR_EQUAL.value, + "<=": VersionComparison.LESS_THAN_OR_EQUAL.value, + } + + return string_to_operator[version_string] + + +@lru_cache() +def split_package_version(package_version_str) -> Tuple[str, str, str]: + pattern = r"([a-zA-Z0-9_-]+)(==|!=|>=|<=|=|<|>)([0-9.]+)" + match = re.match(pattern, package_version_str) + if match: + return (match.group(1), match.group(2), match.group(3)) + else: + raise ValueError(f"Invalid package version string: {package_version_str}") + + +class Backend: + def __init__(self, backend_requirement: str): + self.package_name, self.version_comparison, self.version = split_package_version(backend_requirement) + + if self.package_name not in BACKENDS_MAPPING: + raise ValueError( + f"Backends should be defined in the BACKENDS_MAPPING. 
Offending backend: {self.package_name}" + ) + + def is_satisfied(self) -> bool: + return VersionComparison.from_string(self.version_comparison)( + version.parse(importlib.metadata.version(self.package_name)), version.parse(self.version) + ) + + def __repr__(self) -> str: + return f'Backend("{self.package_name}{self.version_comparison}{self.version}")' + + @property + def error_message(self): + return ( + f"{{0}} requires the {self.package_name} library version {self.version_comparison}{self.version}. That" + f" library was not found with this version in your environment." + ) + + def requires(*, backends=()): """ This decorator enables two things: @@ -2053,15 +2153,22 @@ def requires(*, backends=()): to execute correctly without instantiating it - The '@requires' string is used to dynamically import objects """ - for backend in backends: - if backend not in BACKENDS_MAPPING: - raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}") if not isinstance(backends, tuple): raise ValueError("Backends should be a tuple.") + applied_backends = [] + for backend in backends: + if backend in BACKENDS_MAPPING: + applied_backends.append(backend) + else: + if any(key in backend for key in ["=", "<", ">"]): + applied_backends.append(Backend(backend)) + else: + raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. 
Offending backend: {backend}") + def inner_fn(fun): - fun.__backends = backends + fun.__backends = applied_backends return fun return inner_fn @@ -2369,23 +2476,53 @@ def spread_import_structure(nested_import_structure): """ def propagate_frozenset(unordered_import_structure): - tuple_first_import_structure = {} + frozenset_first_import_structure = {} for _key, _value in unordered_import_structure.items(): + # If the value is not a dict but a string, no need for custom manipulation if not isinstance(_value, dict): - tuple_first_import_structure[_key] = _value + frozenset_first_import_structure[_key] = _value elif any(isinstance(v, frozenset) for v in _value.keys()): - # Here we want to switch around key and v for k, v in _value.items(): if isinstance(k, frozenset): - if k not in tuple_first_import_structure: - tuple_first_import_structure[k] = {} - tuple_first_import_structure[k][_key] = v + # Here we want to switch around _key and k to propagate k upstream if it is a frozenset + if k not in frozenset_first_import_structure: + frozenset_first_import_structure[k] = {} + if _key not in frozenset_first_import_structure[k]: + frozenset_first_import_structure[k][_key] = {} + + frozenset_first_import_structure[k][_key].update(v) + + else: + # If k is not a frozenset, it means that the dictionary is not "level": some keys (top-level) + # are frozensets, whereas some are not -> frozenset keys are at an unknown depth-level of the + # dictionary. + # + # We recursively propagate the frozenset for this specific dictionary so that the frozensets + # are at the top-level when we handle them. 
+ propagated_frozenset = propagate_frozenset({k: v}) + for r_k, r_v in propagated_frozenset.items(): + if isinstance(_key, frozenset): + if r_k not in frozenset_first_import_structure: + frozenset_first_import_structure[r_k] = {} + if _key not in frozenset_first_import_structure[r_k]: + frozenset_first_import_structure[r_k][_key] = {} + + # _key is a frozenset -> we switch around the r_k and _key + frozenset_first_import_structure[r_k][_key].update(r_v) + else: + if _key not in frozenset_first_import_structure: + frozenset_first_import_structure[_key] = {} + if r_k not in frozenset_first_import_structure[_key]: + frozenset_first_import_structure[_key][r_k] = {} + + # _key is not a frozenset -> we keep the order of r_k and _key + frozenset_first_import_structure[_key][r_k].update(r_v) else: - tuple_first_import_structure[_key] = propagate_frozenset(_value) + frozenset_first_import_structure[_key] = propagate_frozenset(_value) - return tuple_first_import_structure + return frozenset_first_import_structure def flatten_dict(_dict, previous_key=None): items = [] diff --git a/tests/utils/import_structures/import_structure_raw_register_with_versions.py b/tests/utils/import_structures/import_structure_raw_register_with_versions.py new file mode 100644 index 00000000000..6d7c10e9793 --- /dev/null +++ b/tests/utils/import_structures/import_structure_raw_register_with_versions.py @@ -0,0 +1,92 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# fmt: off + +from transformers.utils.import_utils import requires + + +@requires(backends=("torch>=2.5",)) +class D0: + def __init__(self): + pass + + +@requires(backends=("torch>=2.5",)) +def d0(): + pass + + +@requires(backends=("torch>2.5",)) +class D1: + def __init__(self): + pass + + +@requires(backends=("torch>2.5",)) +def d1(): + pass + + +@requires(backends=("torch<=2.5",)) +class D2: + def __init__(self): + pass + + +@requires(backends=("torch<=2.5",)) +def d2(): + pass + +@requires(backends=("torch<2.5",)) +class D3: + def __init__(self): + pass + + +@requires(backends=("torch<2.5",)) +def d3(): + pass + + +@requires(backends=("torch==2.5",)) +class D4: + def __init__(self): + pass + + +@requires(backends=("torch==2.5",)) +def d4(): + pass + + +@requires(backends=("torch!=2.5",)) +class D5: + def __init__(self): + pass + + +@requires(backends=("torch!=2.5",)) +def d5(): + pass + +@requires(backends=("torch>=2.5", "accelerate<0.20")) +class D6: + def __init__(self): + pass + + +@requires(backends=("torch>=2.5", "accelerate<0.20")) +def d6(): + pass diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py index 0a9bf38fa40..87a90cae439 100644 --- a/tests/utils/test_import_structure.py +++ b/tests/utils/test_import_structure.py @@ -1,11 +1,19 @@ import os import unittest from pathlib import Path +from typing import Callable -from transformers.utils.import_utils import define_import_structure, spread_import_structure +import pytest + +from transformers.utils.import_utils import ( + Backend, + VersionComparison, + define_import_structure, + spread_import_structure, +) -import_structures = Path("import_structures") +import_structures = Path(__file__).parent / "import_structures" def fetch__all__(file_content): @@ -36,26 +44,39 @@ class TestImportStructures(unittest.TestCase): models_path = base_transformers_path / "src" / 
"transformers" / "models" models_import_structure = spread_import_structure(define_import_structure(models_path)) - # TODO: Lysandre - # See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests - @unittest.skip(reason="failing") def test_definition(self): import_structure = define_import_structure(import_structures) - import_structure_definition = { - frozenset(()): { - "import_structure_raw_register": {"A0", "a0", "A4"}, + valid_frozensets: dict[frozenset | frozenset[str], dict[str, set[str]]] = { + frozenset(): { + "import_structure_raw_register": {"A0", "A4", "a0"}, "import_structure_register_with_comments": {"B0", "b0"}, }, - frozenset(("tf", "torch")): { - "import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"}, - "import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"}, + frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}}, + frozenset({"torch"}): { + "import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"} }, - frozenset(("torch",)): { - "import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"}, + frozenset({"tf", "torch"}): { + "import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"}, + "import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"}, + }, + frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}}, + frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}}, + frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}}, + frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}}, + frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}}, + frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}}, + 
frozenset({"torch>=2.5", "accelerate<0.20"}): { + "import_structure_raw_register_with_versions": {"D6", "d6"} }, } - self.assertDictEqual(import_structure, import_structure_definition) + self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys())) + for _frozenset in valid_frozensets.keys(): + self.assertTrue(_frozenset in import_structure) + self.assertListEqual(list(import_structure[_frozenset].keys()), list(valid_frozensets[_frozenset].keys())) + for module, objects in valid_frozensets[_frozenset].items(): + self.assertTrue(module in import_structure[_frozenset]) + self.assertSetEqual(objects, import_structure[_frozenset][module]) def test_transformers_specific_model_import(self): """ @@ -96,9 +117,92 @@ class TestImportStructures(unittest.TestCase): ) self.assertListEqual(sorted(objects), sorted(_all), msg=error_message) - # TODO: Lysandre - # See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests - @unittest.skip(reason="failing") - def test_export_backend_should_be_defined(self): - with self.assertRaisesRegex(ValueError, "Backend should be defined in the BACKENDS_MAPPING"): - pass + def test_import_spread(self): + """ + This test is specifically designed to test that varying levels of depth across import structures are + respected. 
+ + In this instance, frozensets are at respective depths of 1, 2 and 3, for example: + - models.{frozensets} + - models.albert.{frozensets} + - models.deprecated.transfo_xl.{frozensets} + """ + initial_import_structure = { + frozenset(): {"dummy_non_model": {"DummyObject"}}, + "models": { + frozenset(): {"dummy_config": {"DummyConfig"}}, + "albert": { + frozenset(): {"configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"}}, + frozenset({"torch"}): { + "modeling_albert": { + "AlbertForMaskedLM", + } + }, + }, + "llama": { + frozenset(): {"configuration_llama": {"LlamaConfig"}}, + frozenset({"torch"}): { + "modeling_llama": { + "LlamaForCausalLM", + } + }, + }, + "deprecated": { + "transfo_xl": { + frozenset({"torch"}): { + "modeling_transfo_xl": { + "TransfoXLModel", + } + }, + frozenset(): { + "configuration_transfo_xl": {"TransfoXLConfig"}, + "tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"}, + }, + }, + "deta": { + frozenset({"torch"}): { + "modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"} + }, + frozenset(): {"configuration_deta": {"DetaConfig"}}, + frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}}, + }, + }, + }, + } + + ground_truth_spread_import_structure = { + frozenset(): { + "dummy_non_model": {"DummyObject"}, + "models.dummy_config": {"DummyConfig"}, + "models.albert.configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"}, + "models.llama.configuration_llama": {"LlamaConfig"}, + "models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"}, + "models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"}, + "models.deprecated.deta.configuration_deta": {"DetaConfig"}, + }, + frozenset({"torch"}): { + "models.albert.modeling_albert": {"AlbertForMaskedLM"}, + "models.llama.modeling_llama": {"LlamaForCausalLM"}, + "models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"}, + "models.deprecated.deta.modeling_deta": 
{"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}, + }, + frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}}, + } + + newly_spread_import_structure = spread_import_structure(initial_import_structure) + + self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure) + + +@pytest.mark.parametrize( + "backend,package_name,version_comparison,version", + [ + pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"), + pytest.param(Backend("tf<=1"), "tf", VersionComparison.LESS_THAN_OR_EQUAL.value, "1"), + pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"), + ], +) +def test_backend_specification(backend: Backend, package_name: str, version_comparison: Callable, version: str): + assert backend.package_name == package_name + assert VersionComparison.from_string(backend.version_comparison) == version_comparison + assert backend.version == version