fix fuyu device_map compatibility (#29880)

fix foward
This commit is contained in:
Marc Sun 2024-03-27 10:18:48 +01:00 committed by GitHub
parent 4d8427f739
commit 31c575bcf1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -290,7 +290,9 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
.squeeze(0)
.to(inputs_embeds.device)
for patch in image_patches
]
inputs_embeds = self.gather_continuous_embeddings(