diff --git a/setup.py b/setup.py index 9fe50073052..52024f77c12 100644 --- a/setup.py +++ b/setup.py @@ -125,7 +125,7 @@ _deps = [ "jaxlib>=0.4.1,<=0.4.13", "jieba", "jinja2>=3.1.0", - "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", + "kenlm", # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support. "keras>2.9,<2.16", "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras. @@ -315,7 +315,7 @@ extras["audio"] = deps_list( "librosa", "pyctcdecode", "phonemizer", - "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", + "kenlm", ) # `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead extras["speech"] = deps_list("torchaudio") + extras["audio"] diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index dc2b37a1928..c01f5bb388c 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -32,7 +32,7 @@ deps = { "jaxlib": "jaxlib>=0.4.1,<=0.4.13", "jieba": "jieba", "jinja2": "jinja2>=3.1.0", - "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5": "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", + "kenlm": "kenlm", "keras": "keras>2.9,<2.16", "keras-nlp": "keras-nlp>=0.3.1,<0.14.0", "kernels": "kernels>=0.4.4,<0.5", diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index ef868148b87..d07a768b01a 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -729,23 +729,24 @@ class ParallelInterface(MutableMapping): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if # a new instance is created (in order to locally override a given function) - _global_mapping = { - "colwise": ColwiseParallel(), - "rowwise": RowwiseParallel(), - "colwise_rep": ColwiseParallel(output_layouts=Replicate()), - "rowwise_rep": RowwiseParallel(input_layouts=Replicate()), - "local_colwise": ColwiseParallel(use_dtensor=False), - "local_rowwise": RowwiseParallel(use_dtensor=False), - "local": IsolatedParallel(), - "gather": GatherParallel(), - "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False), - "sequence_parallel": SequenceParallel(), - "replicate": ReplicateParallel(), - } def __init__(self): self._local_mapping = {} + ParallelInterface._global_mapping = { + "colwise": ColwiseParallel(), + "rowwise": RowwiseParallel(), + "colwise_rep": ColwiseParallel(output_layouts=Replicate()), + "rowwise_rep": RowwiseParallel(input_layouts=Replicate()), + "local_colwise": ColwiseParallel(use_dtensor=False), + "local_rowwise": RowwiseParallel(use_dtensor=False), + "local": IsolatedParallel(), + "gather": GatherParallel(), + "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False), + "sequence_parallel": SequenceParallel(), + "replicate": ReplicateParallel(), + } + def __getitem__(self, key): # First check if instance has a local override if key in self._local_mapping: @@ -775,7 +776,11 @@ class ParallelInterface(MutableMapping): # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones -ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface() + +if is_torch_greater_or_equal("2.5") and _torch_distributed_available: + ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface() +else: + ALL_PARALLEL_STYLES = None def convert_local_tensor_to_dtensor(