Protect ParallelInterface behind torch >= 2.5 and torch.distributed availability checks

This commit is contained in:
Lysandre 2025-05-20 18:26:11 +02:00
parent f4ef41c45e
commit cb513e35f9
3 changed files with 22 additions and 17 deletions

View File

@ -125,7 +125,7 @@ _deps = [
"jaxlib>=0.4.1,<=0.4.13",
"jieba",
"jinja2>=3.1.0",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
"kenlm",
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
"keras>2.9,<2.16",
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
@ -315,7 +315,7 @@ extras["audio"] = deps_list(
"librosa",
"pyctcdecode",
"phonemizer",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
"kenlm",
)
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["speech"] = deps_list("torchaudio") + extras["audio"]

View File

@ -32,7 +32,7 @@ deps = {
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
"jieba": "jieba",
"jinja2": "jinja2>=3.1.0",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5": "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
"kenlm": "kenlm",
"keras": "keras>2.9,<2.16",
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
"kernels": "kernels>=0.4.4,<0.5",

View File

@ -729,7 +729,11 @@ class ParallelInterface(MutableMapping):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
def __init__(self):
self._local_mapping = {}
ParallelInterface._global_mapping = {
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
@ -743,9 +747,6 @@ class ParallelInterface(MutableMapping):
"replicate": ReplicateParallel(),
}
def __init__(self):
self._local_mapping = {}
def __getitem__(self, key):
# First check if instance has a local override
if key in self._local_mapping:
@ -775,7 +776,11 @@ class ParallelInterface(MutableMapping):
# Global ParallelInterface shared by all models which do not need to overwrite any of the existing ones
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
else:
ALL_PARALLEL_STYLES = None
def convert_local_tensor_to_dtensor(