Mirror of https://github.com/huggingface/transformers.git
Synced 2025-08-02 19:21:31 +06:00
Fix no split modules underlying modules (#27090)
* fix no split
* style
* remove comm
* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* rename modules
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent 66b088faf0
commit 5be1fb6d1f
@@ -1520,21 +1520,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         Returns:
             `List[str]`: List of modules that should not be split
         """
-        if self._no_split_modules is None:
-            raise ValueError(
-                f"{self.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
-                "class needs to implement the `_no_split_modules` attribute."
-            )
-        _no_split_modules = set(self._no_split_modules)
-        for module in self.modules():
-            if isinstance(module, PreTrainedModel):
-                if module._no_split_modules is None:
-                    raise ValueError(
-                        f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
-                        "class needs to implement the `_no_split_modules` attribute."
-                    )
-                else:
-                    _no_split_modules = _no_split_modules | set(module._no_split_modules)
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, PreTrainedModel):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
         return list(_no_split_modules)
 
     def resize_token_embeddings(
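The old loop over `self.modules()` visited every submodule and raised as soon as it hit a nested `PreTrainedModel` without its own `_no_split_modules`, even when that sub-model's class was already listed as no-split by its parent. The new depth-first walk collects `_no_split_modules` starting from the root and only descends into children whose class is not already marked no-split. Below is a minimal, self-contained sketch of that traversal, not the library code; the stand-in classes `PretrainedStub`, `Outer`, `Inner`, and `InnerBlock` are hypothetical.

# Minimal sketch of the new traversal above; not the library implementation.
# The stand-in classes only mimic the relevant attributes of PreTrainedModel.
import torch.nn as nn


class PretrainedStub(nn.Module):
    # Stand-in for PreTrainedModel: may declare _no_split_modules.
    _no_split_modules = None


class InnerBlock(nn.Module):
    pass


class Inner(PretrainedStub):
    # Nested sub-model without its own _no_split_modules. The old
    # self.modules() loop raised on it; the new walk never inspects it
    # because its class name is already in the parent's no-split set.
    _no_split_modules = None


class Outer(PretrainedStub):
    _no_split_modules = ["Inner", "InnerBlock"]

    def __init__(self):
        super().__init__()
        self.inner = Inner()
        self.block = InnerBlock()


def get_no_split_modules(root, device_map="auto"):
    no_split = set()
    modules_to_check = [root]
    while modules_to_check:
        module = modules_to_check.pop(-1)
        # Only descend into modules whose class is not already marked no-split.
        if module.__class__.__name__ not in no_split:
            if isinstance(module, PretrainedStub):
                if module._no_split_modules is None:
                    raise ValueError(
                        f"{module.__class__.__name__} does not support `device_map='{device_map}'`."
                    )
                no_split |= set(module._no_split_modules)
            modules_to_check += list(module.children())
    return list(no_split)


print(sorted(get_no_split_modules(Outer())))  # ['Inner', 'InnerBlock']

With the old loop, the same `Outer()` example would have raised a `ValueError` because `Inner` defines no `_no_split_modules` of its own.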
@@ -2641,6 +2641,7 @@ class SeamlessM4THifiGan(nn.Module):
 class SeamlessM4TCodeHifiGan(PreTrainedModel):
     config_class = SeamlessM4TConfig
     main_input_name = "input_embeds"
+    _no_split_modules = []
 
     def __init__(self, config):
         super().__init__(config)
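Declaring `_no_split_modules = []` on `SeamlessM4TCodeHifiGan` tells the traversal above that this sub-model supports device maps while contributing no extra no-split blocks, so loading the parent model with a device map no longer raises. A hedged usage sketch follows; the checkpoint id is an assumption, and `accelerate` must be installed for `device_map` support.

# Usage sketch: loading SeamlessM4T with a device map after this fix.
# The checkpoint id below is an assumption; substitute any SeamlessM4T checkpoint.
from transformers import SeamlessM4TModel

model = SeamlessM4TModel.from_pretrained(
    "facebook/hf-seamless-m4t-medium",  # assumed checkpoint id
    device_map="auto",  # previously raised because SeamlessM4TCodeHifiGan had no _no_split_modules
)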