From 41496b95da5d34368a13538548c85527377972a6 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Tue, 24 Oct 2023 15:10:23 +0200
Subject: [PATCH] Add fuyu device map (#26949)

* add _no_split_modules

* style

* fix _no_split_modules

* add doc
---
 src/transformers/modeling_utils.py            | 36 +++++++++++++++----
 src/transformers/models/fuyu/modeling_fuyu.py |  1 +
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0317695f209..cceb064bb44 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1507,6 +1507,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings
 
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
+        get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        if self._no_split_modules is None:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                "class needs to implement the `_no_split_modules` attribute."
+            )
+        _no_split_modules = set(self._no_split_modules)
+        for module in self.modules():
+            if isinstance(module, PreTrainedModel):
+                if module._no_split_modules is None:
+                    raise ValueError(
+                        f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                        "class needs to implement the `_no_split_modules` attribute."
+                    )
+                else:
+                    _no_split_modules = _no_split_modules | set(module._no_split_modules)
+        return list(_no_split_modules)
+
     def resize_token_embeddings(
         self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
     ) -> nn.Embedding:
@@ -3226,12 +3255,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         elif load_in_8bit:
             target_dtype = torch.int8
 
-        if model._no_split_modules is None:
-            raise ValueError(
-                f"{model.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
-                "class needs to implement the `_no_split_modules` attribute."
-            )
-        no_split_modules = model._no_split_modules
+        no_split_modules = model._get_no_split_modules(device_map)
         if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
             raise ValueError(
                 "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py
index b14b1b0b871..03312420ca6 100644
--- a/src/transformers/models/fuyu/modeling_fuyu.py
+++ b/src/transformers/models/fuyu/modeling_fuyu.py
@@ -262,6 +262,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
             inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
             if image_patches is not None and past_key_values is None:
                 patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype))
+                patch_embeddings = patch_embeddings.to(inputs_embeds.device)
                 inputs_embeds = self.gather_continuous_embeddings(
                     word_embeddings=inputs_embeds,
                     continuous_embeddings=patch_embeddings,
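
The new `_get_no_split_modules` walks every nested `PreTrainedModel` and unions their `_no_split_modules`, so a composite model such as `FuyuForCausalLM`, whose `language_model` submodule is itself a `PreTrainedModel`, can now be loaded with a string `device_map`. A minimal sketch of how this code path is exercised from user code; the `adept/fuyu-8b` checkpoint and `torch.float16` dtype are illustrative choices, not mandated by the patch:

    import torch
    from transformers import FuyuForCausalLM

    # device_map="auto" routes through PreTrainedModel._get_no_split_modules(),
    # which aggregates _no_split_modules from all nested PreTrainedModels
    # instead of reading only the top-level attribute.
    model = FuyuForCausalLM.from_pretrained(
        "adept/fuyu-8b",            # illustrative checkpoint
        device_map="auto",
        torch_dtype=torch.float16,  # illustrative dtype for multi-GPU fitting
    )

    # Inspect where accelerate placed each submodule.
    print(model.hf_device_map)

Because submodules can now land on different devices, the one-line fuyu change moves `patch_embeddings` onto `inputs_embeds.device` before `gather_continuous_embeddings`, avoiding a cross-device mismatch when the vision projection and the embedding table sit on different GPUs.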