[mllama] fix loading and inference (#38223)

fix loading
Raushan Turganbay authored on 2025-05-20 17:34:55 +02:00, committed by GitHub
parent 390f153469
commit 2edb0e4b4d

@@ -486,8 +486,6 @@ class MllamaTextCrossAttention(nn.Module):
             value_states = self.v_proj(cross_attention_states)
             key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
             key_states = self.k_norm(key_states)
             if past_key_value is not None:
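
The two deleted lines expanded the cross-attention key/value states to the full query-head count before k_norm and before they were written to the cache, presumably duplicating heads in the cache and clashing with attention code that already repeats them at score-computation time. For reference, a minimal sketch of the repeat_kv helper those lines called (the standard grouped-query-attention utility used across transformers models; shapes follow the (batch, num_key_value_heads, seq_len, head_dim) layout produced by the view/transpose above):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Duplicate each key/value head n_rep times so grouped KV heads line up
    # with the query heads:
    # (batch, n_kv_heads, seq, dim) -> (batch, n_kv_heads * n_rep, seq, dim).
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

Dropping the explicit calls here keeps the cached cross-attention keys/values at num_key_value_heads, leaving the head expansion to the attention implementation itself.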
@@ -850,7 +848,7 @@ class MllamaRotaryEmbedding(nn.Module):
 @auto_docstring
 class MllamaPreTrainedModel(PreTrainedModel):
     config_class = MllamaConfig
-    base_model_prefix = "model"
+    base_model_prefix = ""
     supports_gradient_checkpointing = True
     _no_split_modules = [
         "MllamaVisionEncoderLayer",