Pass eps to Mistral3RMSNorm (#38026)

Pass eps to Mistral3RMSNorm

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Sergio Paniego Blanco 2025-05-19 15:09:25 +02:00 committed by GitHub
parent 6c6302817d
commit 47f8578d96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@@ -104,7 +104,7 @@ class Mistral3PatchMerger(nn.Module):
class Mistral3MultiModalProjector(nn.Module):
def __init__(self, config: Mistral3Config):
super().__init__()
-self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
+self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps)
self.patch_merger = Mistral3PatchMerger(config)
# We have hidden_size * the number of vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)

View File

@@ -82,7 +82,7 @@ class Mistral3PatchMerger(nn.Module):
class Mistral3MultiModalProjector(nn.Module):
def __init__(self, config: Mistral3Config):
super().__init__()
-self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
+self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps)
self.patch_merger = Mistral3PatchMerger(config)
# We have hidden_size * the number of vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)