From 47f8578d96a2f814084d7c04b480bcf894b8101e Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Mon, 19 May 2025 15:09:25 +0200 Subject: [PATCH] Pass `eps` to `Mistral3RMSNorm` (#38026) Pass eps to Mistral3RMSNorm Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/mistral3/modeling_mistral3.py | 2 +- src/transformers/models/mistral3/modular_mistral3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 5f0c9604760..a74a4663fec 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -104,7 +104,7 @@ class Mistral3PatchMerger(nn.Module): class Mistral3MultiModalProjector(nn.Module): def __init__(self, config: Mistral3Config): super().__init__() - self.norm = Mistral3RMSNorm(config.vision_config.hidden_size) + self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps) self.patch_merger = Mistral3PatchMerger(config) # We have hidden_size * the number of vision feature layers num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index a3b5fa5ecab..69232f1ec86 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -82,7 +82,7 @@ class Mistral3PatchMerger(nn.Module): class Mistral3MultiModalProjector(nn.Module): def __init__(self, config: Mistral3Config): super().__init__() - self.norm = Mistral3RMSNorm(config.vision_config.hidden_size) + self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps) self.patch_merger = Mistral3PatchMerger(config) # We have hidden_size * the number of vision feature layers num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)