Mirror of https://github.com/huggingface/transformers.git
Add fuyu device map (#26949)
* add _no_split_modules
* style
* fix _no_split_modules
* add doc
parent b18e31407c
commit 41496b95da
src/transformers/modeling_utils.py

@@ -1507,6 +1507,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin)
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings

+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
+        get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        if self._no_split_modules is None:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                "class needs to implement the `_no_split_modules` attribute."
+            )
+        _no_split_modules = set(self._no_split_modules)
+        for module in self.modules():
+            if isinstance(module, PreTrainedModel):
+                if module._no_split_modules is None:
+                    raise ValueError(
+                        f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                        "class needs to implement the `_no_split_modules` attribute."
+                    )
+                else:
+                    _no_split_modules = _no_split_modules | set(module._no_split_modules)
+        return list(_no_split_modules)
+
     def resize_token_embeddings(
         self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
     ) -> nn.Embedding:
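The aggregation logic is easiest to see on a toy hierarchy. Below is a minimal, self-contained sketch: the Toy* classes are hypothetical stand-ins, not transformers classes, and only the traversal mirrors the new method.

# Sketch of the aggregation behavior added above, using hypothetical Toy*
# classes. Composite models nest other pretrained models, so the new method
# walks self.modules() and unions every `_no_split_modules` list it finds.
import torch.nn as nn


class ToyPreTrainedModel(nn.Module):
    _no_split_modules = None  # subclasses that support device_map override this

    def _get_no_split_modules(self, device_map: str):
        if self._no_split_modules is None:
            raise ValueError(f"{self.__class__.__name__} does not support `device_map='{device_map}'`.")
        no_split = set(self._no_split_modules)
        for module in self.modules():
            if isinstance(module, ToyPreTrainedModel):
                if module._no_split_modules is None:
                    raise ValueError(f"{module.__class__.__name__} does not support `device_map='{device_map}'`.")
                no_split |= set(module._no_split_modules)
        return list(no_split)


class ToyLanguageModel(ToyPreTrainedModel):
    _no_split_modules = ["ToyDecoderLayer"]


class ToyVisionLanguageModel(ToyPreTrainedModel):
    _no_split_modules = ["ToyVisionBlock"]

    def __init__(self):
        super().__init__()
        self.language_model = ToyLanguageModel()  # nested pretrained model


print(sorted(ToyVisionLanguageModel()._get_no_split_modules("auto")))
# ['ToyDecoderLayer', 'ToyVisionBlock']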
@@ -3226,12 +3255,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin)
             elif load_in_8bit:
                 target_dtype = torch.int8

-            if model._no_split_modules is None:
-                raise ValueError(
-                    f"{model.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
-                    "class needs to implement the `_no_split_modules` attribute."
-                )
-            no_split_modules = model._no_split_modules
+            no_split_modules = model._get_no_split_modules(device_map)
            if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
                raise ValueError(
                    "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
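For context, the class names collected by `_get_no_split_modules` are ultimately handed to accelerate when the device map is computed. A hedged sketch of that hand-off, assuming `transformers` and `accelerate` are installed and using gpt2 (whose `GPT2Block` layers are declared unsplittable) as a small stand-in:

# `from_pretrained` does this internally; calling infer_auto_device_map
# directly just makes the hand-off visible.
from accelerate import infer_auto_device_map
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
no_split = model._get_no_split_modules("auto")  # ["GPT2Block"] for gpt2
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=no_split,
    max_memory={0: "1GiB", "cpu": "8GiB"},  # illustrative budget; adjust to your hardware
)
# Each GPT2Block lands on exactly one device; only whole blocks are moved.
print(device_map)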
src/transformers/models/fuyu/modeling_fuyu.py

@@ -262,6 +262,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
             inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
             if image_patches is not None and past_key_values is None:
                 patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype))
+                patch_embeddings = patch_embeddings.to(inputs_embeds.device)
                 inputs_embeds = self.gather_continuous_embeddings(
                     word_embeddings=inputs_embeds,
                     continuous_embeddings=patch_embeddings,
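The added `.to(inputs_embeds.device)` matters because, under a device map, `vision_embed_tokens` and the language model's embedding layer can be placed on different devices. A hedged usage sketch of the call path this enables, assuming a GPU setup with enough memory for the adept/fuyu-8b checkpoint:

# With device_map="auto", patch embeddings may be produced on a different
# device than the word embeddings, hence the .to(...) in the diff above.
import torch
from PIL import Image
from transformers import FuyuForCausalLM, FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
model = FuyuForCausalLM.from_pretrained(
    "adept/fuyu-8b", device_map="auto", torch_dtype=torch.float16
)

image = Image.new("RGB", (300, 300), "white")  # placeholder image for illustration
inputs = processor(text="Describe the image:\n", images=image, return_tensors="pt")
inputs = inputs.to("cuda:0")  # inputs start on the first device of the map
outputs = model.generate(**inputs, max_new_tokens=16)
print(processor.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True))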