Fix `_no_split_modules` collection for underlying modules (#27090)

* fix no split

* style

* remove comm

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* rename modules

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Marc Sun 2023-10-27 15:49:20 +02:00 committed by GitHub
parent 66b088faf0
commit 5be1fb6d1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 15 deletions

View File

@ -1520,21 +1520,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Returns:
`List[str]`: List of modules that should not be split
"""
if self._no_split_modules is None:
raise ValueError(
f"{self.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
_no_split_modules = set(self._no_split_modules)
for module in self.modules():
if isinstance(module, PreTrainedModel):
if module._no_split_modules is None:
raise ValueError(
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
else:
_no_split_modules = _no_split_modules | set(module._no_split_modules)
_no_split_modules = set()
modules_to_check = [self]
while len(modules_to_check) > 0:
module = modules_to_check.pop(-1)
# if the module's class name is not already in _no_split_modules, we check the module itself and then queue its children
if module.__class__.__name__ not in _no_split_modules:
if isinstance(module, PreTrainedModel):
if module._no_split_modules is None:
raise ValueError(
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
else:
_no_split_modules = _no_split_modules | set(module._no_split_modules)
modules_to_check += list(module.children())
return list(_no_split_modules)
def resize_token_embeddings(

View File

@ -2641,6 +2641,7 @@ class SeamlessM4THifiGan(nn.Module):
class SeamlessM4TCodeHifiGan(PreTrainedModel):
config_class = SeamlessM4TConfig
main_input_name = "input_embeds"
_no_split_modules = []
def __init__(self, config):
super().__init__(config)