mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Protect ParallelInterface
This commit is contained in:
parent
f4ef41c45e
commit
cb513e35f9
4
setup.py
4
setup.py
@ -125,7 +125,7 @@ _deps = [
|
||||
"jaxlib>=0.4.1,<=0.4.13",
|
||||
"jieba",
|
||||
"jinja2>=3.1.0",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm",
|
||||
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
|
||||
"keras>2.9,<2.16",
|
||||
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
|
||||
@ -315,7 +315,7 @@ extras["audio"] = deps_list(
|
||||
"librosa",
|
||||
"pyctcdecode",
|
||||
"phonemizer",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm",
|
||||
)
|
||||
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
|
||||
extras["speech"] = deps_list("torchaudio") + extras["audio"]
|
||||
|
@ -32,7 +32,7 @@ deps = {
|
||||
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
|
||||
"jieba": "jieba",
|
||||
"jinja2": "jinja2>=3.1.0",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5": "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm": "kenlm",
|
||||
"keras": "keras>2.9,<2.16",
|
||||
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
|
||||
"kernels": "kernels>=0.4.4,<0.5",
|
||||
|
@ -729,7 +729,11 @@ class ParallelInterface(MutableMapping):
|
||||
|
||||
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
||||
# a new instance is created (in order to locally override a given function)
|
||||
_global_mapping = {
|
||||
|
||||
def __init__(self):
|
||||
self._local_mapping = {}
|
||||
|
||||
ParallelInterface._global_mapping = {
|
||||
"colwise": ColwiseParallel(),
|
||||
"rowwise": RowwiseParallel(),
|
||||
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
|
||||
@ -743,9 +747,6 @@ class ParallelInterface(MutableMapping):
|
||||
"replicate": ReplicateParallel(),
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self._local_mapping = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First check if instance has a local override
|
||||
if key in self._local_mapping:
|
||||
@ -775,7 +776,11 @@ class ParallelInterface(MutableMapping):
|
||||
|
||||
|
||||
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
|
||||
|
||||
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
|
||||
else:
|
||||
ALL_PARALLEL_STYLES = None
|
||||
|
||||
|
||||
def convert_local_tensor_to_dtensor(
|
||||
|
Loading…
Reference in New Issue
Block a user