Add fuyu device map (#26949)

* add _no_split_modules

* style

* fix _no_split_modules

* add doc
This commit is contained in:
Marc Sun 2023-10-24 15:10:23 +02:00 committed by GitHub
parent b18e31407c
commit 41496b95da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 6 deletions

View File

@ -1507,6 +1507,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings
def _get_no_split_modules(self, device_map: str):
    """
    Get the modules of the model that should not be split when using device_map. We iterate through the modules to
    collect the `_no_split_modules` of the model itself and of every nested `PreTrainedModel` it contains.

    Args:
        device_map (`str`):
            The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]. It is only
            used here to build the error message.

    Returns:
        `List[str]`: List of modules that should not be split

    Raises:
        ValueError: If the model, or any `PreTrainedModel` nested inside it, has `_no_split_modules` set to
            `None`.
    """
    no_split_modules = set()
    # `self.modules()` yields `self` first, so the top-level model goes through the same check as its nested
    # `PreTrainedModel` sub-models; no separate up-front check is needed.
    for module in self.modules():
        if not isinstance(module, PreTrainedModel):
            continue
        if module._no_split_modules is None:
            raise ValueError(
                f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
                "class needs to implement the `_no_split_modules` attribute."
            )
        no_split_modules.update(module._no_split_modules)
    return list(no_split_modules)
def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> nn.Embedding:
@ -3226,12 +3255,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif load_in_8bit:
target_dtype = torch.int8
if model._no_split_modules is None:
raise ValueError(
f"{model.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
no_split_modules = model._no_split_modules
no_split_modules = model._get_no_split_modules(device_map)
if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "

View File

@ -262,6 +262,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype))
patch_embeddings = patch_embeddings.to(inputs_embeds.device)
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,