mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Support for version spec in requires & arbitrary mismatching depths across folders (#37854)
* Support for version spec in requires & arbitrary mismatching depths * Quality * Testing
This commit is contained in:
parent
774dc274ac
commit
23d79cea75
@ -84,6 +84,19 @@ class Trainer:
|
||||
|
||||
Backends that can be added here are all the backends that are available in the `import_utils.py` module.
|
||||
|
||||
Additionally, specific versions can be specified in each backend. For example, this is how you would specify
|
||||
a requirement on torch>=2.6 on the `Trainer` class:
|
||||
|
||||
```python
|
||||
from .utils.import_utils import requires
|
||||
|
||||
@requires(backends=("torch>=2.6", "accelerate"))
|
||||
class Trainer:
|
||||
...
|
||||
```
|
||||
|
||||
You can specify the following operators: `==`, `>`, `>=`, `<`, `<=`, `!=`.
|
||||
|
||||
## Methods
|
||||
|
||||
[[autodoc]] utils.import_utils.define_import_structure
|
||||
|
@ -19,16 +19,19 @@ import importlib.machinery
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import json
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@ -1838,8 +1841,16 @@ def requires_backends(obj, backends):
|
||||
if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
|
||||
raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
|
||||
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
||||
failed = [msg.format(name) for available, msg in checks if not available()]
|
||||
failed = []
|
||||
for backend in backends:
|
||||
if isinstance(backend, Backend):
|
||||
available, msg = backend.is_satisfied, backend.error_message
|
||||
else:
|
||||
available, msg = BACKENDS_MAPPING[backend]
|
||||
|
||||
if not available():
|
||||
failed.append(msg.format(name))
|
||||
|
||||
if failed:
|
||||
raise ImportError("".join(failed))
|
||||
|
||||
@ -1884,10 +1895,13 @@ class _LazyModule(ModuleType):
|
||||
import_structure: IMPORT_STRUCTURE_T,
|
||||
module_spec: Optional[importlib.machinery.ModuleSpec] = None,
|
||||
extra_objects: Optional[Dict[str, object]] = None,
|
||||
explicit_import_shortcut: Optional[Dict[str, List[str]]] = None,
|
||||
):
|
||||
super().__init__(name)
|
||||
|
||||
self._object_missing_backend = {}
|
||||
self._explicit_import_shortcut = explicit_import_shortcut if explicit_import_shortcut else {}
|
||||
|
||||
if any(isinstance(key, frozenset) for key in import_structure.keys()):
|
||||
self._modules = set()
|
||||
self._class_to_module = {}
|
||||
@ -1916,14 +1930,25 @@ class _LazyModule(ModuleType):
|
||||
module_keys = set(
|
||||
chain(*[[k.rsplit(".", i)[0] for i in range(k.count(".") + 1)] for k in list(module.keys())])
|
||||
)
|
||||
|
||||
for backend in backends:
|
||||
if backend not in BACKENDS_MAPPING:
|
||||
raise ValueError(
|
||||
f"Error: the following backend: '{backend}' was specified around object {module} but isn't specified in the backends mapping."
|
||||
)
|
||||
callable, error = BACKENDS_MAPPING[backend]
|
||||
if not callable():
|
||||
if backend in BACKENDS_MAPPING:
|
||||
callable, _ = BACKENDS_MAPPING[backend]
|
||||
else:
|
||||
if any(key in backend for key in ["=", "<", ">"]):
|
||||
backend = Backend(backend)
|
||||
callable = backend.is_satisfied
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}"
|
||||
)
|
||||
|
||||
try:
|
||||
if not callable():
|
||||
missing_backends.append(backend)
|
||||
except (importlib.metadata.PackageNotFoundError, ModuleNotFoundError, RuntimeError):
|
||||
missing_backends.append(backend)
|
||||
|
||||
self._modules = self._modules.union(module_keys)
|
||||
|
||||
for key, values in module.items():
|
||||
@ -2000,12 +2025,29 @@ class _LazyModule(ModuleType):
|
||||
|
||||
value = Placeholder
|
||||
elif name in self._class_to_module.keys():
|
||||
module = self._get_module(self._class_to_module[name])
|
||||
value = getattr(module, name)
|
||||
try:
|
||||
module = self._get_module(self._class_to_module[name])
|
||||
value = getattr(module, name)
|
||||
except (ModuleNotFoundError, RuntimeError) as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Could not import module '{name}'. Are this object's requirements defined correctly?"
|
||||
) from e
|
||||
|
||||
elif name in self._modules:
|
||||
value = self._get_module(name)
|
||||
try:
|
||||
value = self._get_module(name)
|
||||
except (ModuleNotFoundError, RuntimeError) as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Could not import module '{name}'. Are this object's requirements defined correctly?"
|
||||
) from e
|
||||
else:
|
||||
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
value = None
|
||||
for key, values in self._explicit_import_shortcut.items():
|
||||
if name in values:
|
||||
value = self._get_module(key)
|
||||
|
||||
if value is None:
|
||||
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
@ -2046,6 +2088,64 @@ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
|
||||
return module
|
||||
|
||||
|
||||
class VersionComparison(Enum):
|
||||
EQUAL = operator.eq
|
||||
NOT_EQUAL = operator.ne
|
||||
GREATER_THAN = operator.gt
|
||||
LESS_THAN = operator.lt
|
||||
GREATER_THAN_OR_EQUAL = operator.ge
|
||||
LESS_THAN_OR_EQUAL = operator.le
|
||||
|
||||
@staticmethod
|
||||
def from_string(version_string: str) -> "VersionComparison":
|
||||
string_to_operator = {
|
||||
"=": VersionComparison.EQUAL.value,
|
||||
"==": VersionComparison.EQUAL.value,
|
||||
"!=": VersionComparison.NOT_EQUAL.value,
|
||||
">": VersionComparison.GREATER_THAN.value,
|
||||
"<": VersionComparison.LESS_THAN.value,
|
||||
">=": VersionComparison.GREATER_THAN_OR_EQUAL.value,
|
||||
"<=": VersionComparison.LESS_THAN_OR_EQUAL.value,
|
||||
}
|
||||
|
||||
return string_to_operator[version_string]
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def split_package_version(package_version_str) -> Tuple[str, str, str]:
|
||||
pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
|
||||
match = re.match(pattern, package_version_str)
|
||||
if match:
|
||||
return (match.group(1), match.group(2), match.group(3))
|
||||
else:
|
||||
raise ValueError(f"Invalid package version string: {package_version_str}")
|
||||
|
||||
|
||||
class Backend:
|
||||
def __init__(self, backend_requirement: str):
|
||||
self.package_name, self.version_comparison, self.version = split_package_version(backend_requirement)
|
||||
|
||||
if self.package_name not in BACKENDS_MAPPING:
|
||||
raise ValueError(
|
||||
f"Backends should be defined in the BACKENDS_MAPPING. Offending backend: {self.package_name}"
|
||||
)
|
||||
|
||||
def is_satisfied(self) -> bool:
|
||||
return VersionComparison.from_string(self.version_comparison)(
|
||||
version.parse(importlib.metadata.version(self.package_name)), version.parse(self.version)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Backend("{self.package_name}", {VersionComparison[self.version_comparison]}, "{self.version}")'
|
||||
|
||||
@property
|
||||
def error_message(self):
|
||||
return (
|
||||
f"{{0}} requires the {self.package_name} library version {self.version_comparison}{self.version}. That"
|
||||
f" library was not found with this version in your environment."
|
||||
)
|
||||
|
||||
|
||||
def requires(*, backends=()):
|
||||
"""
|
||||
This decorator enables two things:
|
||||
@ -2053,15 +2153,22 @@ def requires(*, backends=()):
|
||||
to execute correctly without instantiating it
|
||||
- The '@requires' string is used to dynamically import objects
|
||||
"""
|
||||
for backend in backends:
|
||||
if backend not in BACKENDS_MAPPING:
|
||||
raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}")
|
||||
|
||||
if not isinstance(backends, tuple):
|
||||
raise ValueError("Backends should be a tuple.")
|
||||
|
||||
applied_backends = []
|
||||
for backend in backends:
|
||||
if backend in BACKENDS_MAPPING:
|
||||
applied_backends.append(backend)
|
||||
else:
|
||||
if any(key in backend for key in ["=", "<", ">"]):
|
||||
applied_backends.append(Backend(backend))
|
||||
else:
|
||||
raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}")
|
||||
|
||||
def inner_fn(fun):
|
||||
fun.__backends = backends
|
||||
fun.__backends = applied_backends
|
||||
return fun
|
||||
|
||||
return inner_fn
|
||||
@ -2369,23 +2476,53 @@ def spread_import_structure(nested_import_structure):
|
||||
"""
|
||||
|
||||
def propagate_frozenset(unordered_import_structure):
|
||||
tuple_first_import_structure = {}
|
||||
frozenset_first_import_structure = {}
|
||||
for _key, _value in unordered_import_structure.items():
|
||||
# If the value is not a dict but a string, no need for custom manipulation
|
||||
if not isinstance(_value, dict):
|
||||
tuple_first_import_structure[_key] = _value
|
||||
frozenset_first_import_structure[_key] = _value
|
||||
|
||||
elif any(isinstance(v, frozenset) for v in _value.keys()):
|
||||
# Here we want to switch around key and v
|
||||
for k, v in _value.items():
|
||||
if isinstance(k, frozenset):
|
||||
if k not in tuple_first_import_structure:
|
||||
tuple_first_import_structure[k] = {}
|
||||
tuple_first_import_structure[k][_key] = v
|
||||
# Here we want to switch around _key and k to propagate k upstream if it is a frozenset
|
||||
if k not in frozenset_first_import_structure:
|
||||
frozenset_first_import_structure[k] = {}
|
||||
if _key not in frozenset_first_import_structure[k]:
|
||||
frozenset_first_import_structure[k][_key] = {}
|
||||
|
||||
frozenset_first_import_structure[k][_key].update(v)
|
||||
|
||||
else:
|
||||
# If k is not a frozenset, it means that the dictionary is not "level": some keys (top-level)
|
||||
# are frozensets, whereas some are not -> frozenset keys are at an unkown depth-level of the
|
||||
# dictionary.
|
||||
#
|
||||
# We recursively propagate the frozenset for this specific dictionary so that the frozensets
|
||||
# are at the top-level when we handle them.
|
||||
propagated_frozenset = propagate_frozenset({k: v})
|
||||
for r_k, r_v in propagated_frozenset.items():
|
||||
if isinstance(_key, frozenset):
|
||||
if r_k not in frozenset_first_import_structure:
|
||||
frozenset_first_import_structure[r_k] = {}
|
||||
if _key not in frozenset_first_import_structure[r_k]:
|
||||
frozenset_first_import_structure[r_k][_key] = {}
|
||||
|
||||
# _key is a frozenset -> we switch around the r_k and _key
|
||||
frozenset_first_import_structure[r_k][_key].update(r_v)
|
||||
else:
|
||||
if _key not in frozenset_first_import_structure:
|
||||
frozenset_first_import_structure[_key] = {}
|
||||
if r_k not in frozenset_first_import_structure[_key]:
|
||||
frozenset_first_import_structure[_key][r_k] = {}
|
||||
|
||||
# _key is not a frozenset -> we keep the order of r_k and _key
|
||||
frozenset_first_import_structure[_key][r_k].update(r_v)
|
||||
|
||||
else:
|
||||
tuple_first_import_structure[_key] = propagate_frozenset(_value)
|
||||
frozenset_first_import_structure[_key] = propagate_frozenset(_value)
|
||||
|
||||
return tuple_first_import_structure
|
||||
return frozenset_first_import_structure
|
||||
|
||||
def flatten_dict(_dict, previous_key=None):
|
||||
items = []
|
||||
|
@ -0,0 +1,92 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import requires
|
||||
|
||||
|
||||
@requires(backends=("torch>=2.5",))
|
||||
class D0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>=2.5",))
|
||||
def d0():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>2.5",))
|
||||
class D1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>2.5",))
|
||||
def d1():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch<=2.5",))
|
||||
class D2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch<=2.5",))
|
||||
def d2():
|
||||
pass
|
||||
|
||||
@requires(backends=("torch<2.5",))
|
||||
class D3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch<2.5",))
|
||||
def d3():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch==2.5",))
|
||||
class D4:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch==2.5",))
|
||||
def d4():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch!=2.5",))
|
||||
class D5:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch!=2.5",))
|
||||
def d5():
|
||||
pass
|
||||
|
||||
@requires(backends=("torch>=2.5", "accelerate<0.20"))
|
||||
class D6:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>=2.5", "accelerate<0.20"))
|
||||
def d6():
|
||||
pass
|
@ -1,11 +1,19 @@
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from transformers.utils.import_utils import define_import_structure, spread_import_structure
|
||||
import pytest
|
||||
|
||||
from transformers.utils.import_utils import (
|
||||
Backend,
|
||||
VersionComparison,
|
||||
define_import_structure,
|
||||
spread_import_structure,
|
||||
)
|
||||
|
||||
|
||||
import_structures = Path("import_structures")
|
||||
import_structures = Path(__file__).parent / "import_structures"
|
||||
|
||||
|
||||
def fetch__all__(file_content):
|
||||
@ -36,26 +44,39 @@ class TestImportStructures(unittest.TestCase):
|
||||
models_path = base_transformers_path / "src" / "transformers" / "models"
|
||||
models_import_structure = spread_import_structure(define_import_structure(models_path))
|
||||
|
||||
# TODO: Lysandre
|
||||
# See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
|
||||
@unittest.skip(reason="failing")
|
||||
def test_definition(self):
|
||||
import_structure = define_import_structure(import_structures)
|
||||
import_structure_definition = {
|
||||
frozenset(()): {
|
||||
"import_structure_raw_register": {"A0", "a0", "A4"},
|
||||
valid_frozensets: dict[frozenset | frozenset[str], dict[str, set[str]]] = {
|
||||
frozenset(): {
|
||||
"import_structure_raw_register": {"A0", "A4", "a0"},
|
||||
"import_structure_register_with_comments": {"B0", "b0"},
|
||||
},
|
||||
frozenset(("tf", "torch")): {
|
||||
"import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"},
|
||||
"import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"},
|
||||
frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}},
|
||||
frozenset({"torch"}): {
|
||||
"import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"}
|
||||
},
|
||||
frozenset(("torch",)): {
|
||||
"import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"},
|
||||
frozenset({"tf", "torch"}): {
|
||||
"import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"},
|
||||
"import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"},
|
||||
},
|
||||
frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}},
|
||||
frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}},
|
||||
frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}},
|
||||
frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}},
|
||||
frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}},
|
||||
frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}},
|
||||
frozenset({"torch>=2.5", "accelerate<0.20"}): {
|
||||
"import_structure_raw_register_with_versions": {"D6", "d6"}
|
||||
},
|
||||
}
|
||||
|
||||
self.assertDictEqual(import_structure, import_structure_definition)
|
||||
self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys()))
|
||||
for _frozenset in valid_frozensets.keys():
|
||||
self.assertTrue(_frozenset in import_structure)
|
||||
self.assertListEqual(list(import_structure[_frozenset].keys()), list(valid_frozensets[_frozenset].keys()))
|
||||
for module, objects in valid_frozensets[_frozenset].items():
|
||||
self.assertTrue(module in import_structure[_frozenset])
|
||||
self.assertSetEqual(objects, import_structure[_frozenset][module])
|
||||
|
||||
def test_transformers_specific_model_import(self):
|
||||
"""
|
||||
@ -96,9 +117,92 @@ class TestImportStructures(unittest.TestCase):
|
||||
)
|
||||
self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)
|
||||
|
||||
# TODO: Lysandre
|
||||
# See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
|
||||
@unittest.skip(reason="failing")
|
||||
def test_export_backend_should_be_defined(self):
|
||||
with self.assertRaisesRegex(ValueError, "Backend should be defined in the BACKENDS_MAPPING"):
|
||||
pass
|
||||
def test_import_spread(self):
|
||||
"""
|
||||
This test is specifically designed to test that varying levels of depth across import structures are
|
||||
respected.
|
||||
|
||||
In this instance, frozensets are at respective depths of 1, 2 and 3, for example:
|
||||
- models.{frozensets}
|
||||
- models.albert.{frozensets}
|
||||
- models.deprecated.transfo_xl.{frozensets}
|
||||
"""
|
||||
initial_import_structure = {
|
||||
frozenset(): {"dummy_non_model": {"DummyObject"}},
|
||||
"models": {
|
||||
frozenset(): {"dummy_config": {"DummyConfig"}},
|
||||
"albert": {
|
||||
frozenset(): {"configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"}},
|
||||
frozenset({"torch"}): {
|
||||
"modeling_albert": {
|
||||
"AlbertForMaskedLM",
|
||||
}
|
||||
},
|
||||
},
|
||||
"llama": {
|
||||
frozenset(): {"configuration_llama": {"LlamaConfig"}},
|
||||
frozenset({"torch"}): {
|
||||
"modeling_llama": {
|
||||
"LlamaForCausalLM",
|
||||
}
|
||||
},
|
||||
},
|
||||
"deprecated": {
|
||||
"transfo_xl": {
|
||||
frozenset({"torch"}): {
|
||||
"modeling_transfo_xl": {
|
||||
"TransfoXLModel",
|
||||
}
|
||||
},
|
||||
frozenset(): {
|
||||
"configuration_transfo_xl": {"TransfoXLConfig"},
|
||||
"tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
|
||||
},
|
||||
},
|
||||
"deta": {
|
||||
frozenset({"torch"}): {
|
||||
"modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}
|
||||
},
|
||||
frozenset(): {"configuration_deta": {"DetaConfig"}},
|
||||
frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ground_truth_spread_import_structure = {
|
||||
frozenset(): {
|
||||
"dummy_non_model": {"DummyObject"},
|
||||
"models.dummy_config": {"DummyConfig"},
|
||||
"models.albert.configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"},
|
||||
"models.llama.configuration_llama": {"LlamaConfig"},
|
||||
"models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"},
|
||||
"models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
|
||||
"models.deprecated.deta.configuration_deta": {"DetaConfig"},
|
||||
},
|
||||
frozenset({"torch"}): {
|
||||
"models.albert.modeling_albert": {"AlbertForMaskedLM"},
|
||||
"models.llama.modeling_llama": {"LlamaForCausalLM"},
|
||||
"models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"},
|
||||
"models.deprecated.deta.modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"},
|
||||
},
|
||||
frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}},
|
||||
}
|
||||
|
||||
newly_spread_import_structure = spread_import_structure(initial_import_structure)
|
||||
|
||||
self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend,package_name,version_comparison,version",
|
||||
[
|
||||
pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"),
|
||||
pytest.param(Backend("tf<=1"), "tf", VersionComparison.LESS_THAN_OR_EQUAL.value, "1"),
|
||||
pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"),
|
||||
],
|
||||
)
|
||||
def test_backend_specification(backend: Backend, package_name: str, version_comparison: Callable, version: str):
|
||||
assert backend.package_name == package_name
|
||||
assert VersionComparison.from_string(backend.version_comparison) == version_comparison
|
||||
assert backend.version == version
|
||||
|
Loading…
Reference in New Issue
Block a user