Detect and fix most _init_weights() issues - make it work for composite models (#37070)

* Update test_modeling_common.py

* Fix Llama and its modular children

* Update test_modeling_common.py

* qwen3

* first try at prioritizing models

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* test

* fix

* fix

* more models

* more

* more

* more

* smarter init for composite models!

* fix post rebase

* smol

* fix missing args

* more

* typo

* Super elegant and efficient init for submodels

* Update modeling_utils.py

* style

* last fixes

* cleanup

* finalize cleanup

* CIs

* improve docstring

* Update modeling_utils.py

* llama4

* style

* CIs

* style

* add dpt

* granite speech

* qwen 2.5 omni

* better fix

* Parse the config file instead

* CIs
Cyril Vallez 2025-04-14 16:19:04 +02:00 committed by GitHub
parent 1897a02d83
commit 4e53840920
103 changed files with 1164 additions and 795 deletions

View File

@ -2449,6 +2449,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
self._init_weights(module)
module._is_hf_initialized = True
@torch.no_grad()
def initialize_weights(self):
"""
This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
model would have to recurse a second time on all sub-models explicitly in the outermost `_init_weights`, which
is extremely error-prone and inefficient.
Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
`torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
`module.weight.data.zero_()`.
"""
if not hasattr(torch.nn.Module, "smart_apply"):
# This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjusts the function
# to apply as we go down the graph
def smart_apply(self, fn):
for module in self.children():
# We found a sub-model: recursively dispatch its own init function now!
if hasattr(module, "_init_weights"):
module.smart_apply(module._initialize_weights)
else:
module.smart_apply(fn)
fn(self)
return self
torch.nn.Module.smart_apply = smart_apply
# Let the magic happen with this simple call
self.smart_apply(self._initialize_weights)
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
@ -3074,7 +3105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if _init_weights:
# Initialize weights
self.apply(self._initialize_weights)
self.initialize_weights()
# Tie weights should be skipped when not initializing all weights
# since from_pretrained(...) calls tie weights anyways
@ -5286,9 +5317,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
)
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
self.apply(self._initialize_weights)
self.initialize_weights()
else:
self.apply(self._initialize_weights)
self.initialize_weights()
def get_parameter_or_buffer(self, target: str):
"""

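For reference, here is a minimal, self-contained sketch of the dispatch idea described in the docstring above. The `InnerModel`/`OuterModel` names are hypothetical (not from this commit), and the helper mirrors the `smart_apply` patch shown in the hunk rather than the real `_initialize_weights` plumbing: any child that defines its own `_init_weights` gets that function applied to its subtree, instead of inheriting the outermost one.

import torch
import torch.nn as nn

class InnerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def _init_weights(self, module):
        # The sub-model wants its own init scheme
        if isinstance(module, nn.Linear):
            module.weight.data.fill_(2.0)

class OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = InnerModel()
        self.head = nn.Linear(4, 4)

    def _init_weights(self, module):
        # The composite model's init would zero every Linear it touches
        if isinstance(module, nn.Linear):
            module.weight.data.zero_()

def smart_apply(module, fn):
    # Same traversal as the patched torch.nn.Module.smart_apply above:
    # switch to a child's own `_init_weights` as soon as we enter a sub-model
    # that defines one, otherwise keep propagating the current function.
    for child in module.children():
        if hasattr(child, "_init_weights"):
            smart_apply(child, child._init_weights)
        else:
            smart_apply(child, fn)
    fn(module)
    return module

model = OuterModel()
with torch.no_grad():
    smart_apply(model, model._init_weights)
print(model.head.weight[0, 0].item())        # 0.0 -> outer init was used
print(model.inner.proj.weight[0, 0].item())  # 2.0 -> inner init was respected

With the previous `self.apply(self._initialize_weights)`, only the outermost `_init_weights` would ever run, so `inner.proj` would have been zeroed as well unless the outer model explicitly recursed into the inner init, which is exactly the pattern removed from models like Chameleon and Emu3 further down in this diff.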
View File

@ -679,12 +679,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaTextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
ARIA_TEXT_START_DOCSTRING = r"""
@ -724,14 +722,17 @@ class AriaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, AriaProjector):
nn.init.trunc_normal_(module.query, std=std)
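The `nn.MultiheadAttention` branch above does not hand-roll an init; it defers to PyTorch's own `_reset_parameters()`, which re-runs the module's default (xavier-uniform) initialization from construction time. A standalone sketch of that delegation (plain `torch.nn`, nothing model-specific):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=2)
with torch.no_grad():
    mha.in_proj_weight.zero_()   # pretend the weights were clobbered
mha._reset_parameters()          # restore torch's built-in initialization
print(mha.in_proj_weight.abs().sum() > 0)  # tensor(True)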

View File

@ -1255,12 +1255,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaTextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
class AriaPreTrainedModel(LlamaPreTrainedModel):
@ -1269,14 +1267,17 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, AriaProjector):
nn.init.trunc_normal_(module.query, std=std)

View File

@ -127,26 +127,19 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
_supports_static_cache = False
def _init_weights(self, module):
# important: this ported version of AyaVision isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/AyaVision/tree/main/aya_vision should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
@dataclass

View File

@ -113,6 +113,21 @@ class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
_supports_quantized_cache = False
_supports_static_cache = False
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass

View File

@ -1052,10 +1052,16 @@ class BambaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, BambaMixer):
module.dt_bias.data.fill_(1.0)
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
module.D.data.fill_(1.0)
BAMBA_INPUTS_DOCSTRING = r"""

View File

@ -820,10 +820,16 @@ class BambaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, BambaMixer):
module.dt_bias.data.fill_(1.0)
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
module.D.data.fill_(1.0)
BAMBA_INPUTS_DOCSTRING = r"""

View File

@ -423,22 +423,30 @@ class Blip2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_range
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=factor)
if hasattr(module, "bias") and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, Blip2VisionEmbeddings):
if hasattr(self.config, "vision_config") and not isinstance(self.config, Blip2VisionConfig):
factor = self.config.vision_config.initializer_range
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=factor)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, Blip2VisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(
module,
(
Blip2Model,
Blip2TextModelWithProjection,
Blip2VisionModelWithProjection,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
),
):
module.query_tokens.data.zero_()
BLIP_2_START_DOCSTRING = r"""

View File

@ -1056,12 +1056,16 @@ class ChameleonPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, ChameleonVQVAE):
module.apply(module._init_weights)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ChameleonRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
@ -1096,18 +1100,6 @@ class ChameleonVQVAE(ChameleonPreTrainedModel):
config_class = ChameleonVQVAEConfig
_no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.GroupNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
def __init__(self, config: ChameleonVQVAEConfig):
super().__init__(config)

View File

@ -416,6 +416,8 @@ class CoherePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, CohereLayerNorm):
module.weight.data.fill_(1.0)
COHERE_INPUTS_DOCSTRING = r"""

View File

@ -41,6 +41,7 @@ from ..llama.modeling_llama import (
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
eager_attention_forward,
)
@ -277,6 +278,21 @@ class CohereDecoderLayer(nn.Module):
return outputs
class CoherePreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, CohereLayerNorm):
module.weight.data.fill_(1.0)
class CohereModel(LlamaModel):
def __init__(self, config: CohereConfig):
super().__init__(config)

View File

@ -424,6 +424,8 @@ class Cohere2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Cohere2LayerNorm):
module.weight.data.fill_(1.0)
COHERE2_INPUTS_DOCSTRING = r"""

View File

@ -557,10 +557,10 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DeepseekV3RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, DeepseekV3TopkRouter):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Parameter):
module.weight.data.normal_(mean=0.0, std=std)
DEEPSEEK_V3_INPUTS_DOCSTRING = r"""

View File

@ -347,10 +347,10 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DeepseekV3RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, DeepseekV3TopkRouter):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Parameter):
module.weight.data.normal_(mean=0.0, std=std)
class DeepseekV3Model(LlamaModel):

View File

@ -625,6 +625,13 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821
module.weight.data.fill_(1.0)
elif isinstance(module, DiffLlamaAttention):
module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
class DiffLlamaRotaryEmbedding(nn.Module):

View File

@ -431,6 +431,24 @@ class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
_supports_flex_attn = False
_supports_attention_backend = False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821
module.weight.data.fill_(1.0)
elif isinstance(module, DiffLlamaAttention):
module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
class DiffLlamaModel(LlamaModel):
pass

View File

@ -852,7 +852,7 @@ class DPTPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)):

View File

@ -1020,6 +1020,10 @@ class Emu3VQVAE(PreTrainedModel):
def _init_weights(self, module):
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
if module.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(module.bias, -bound, bound)
elif isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
@ -1027,8 +1031,12 @@ class Emu3VQVAE(PreTrainedModel):
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(module.bias, -bound, bound)
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
nn.init.constant_(module.bias, 0.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_()
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def __init__(self, config: Emu3VQVAEConfig):
super().__init__(config)
@ -1198,9 +1206,7 @@ class Emu3PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.get_text_config().initializer_range
if isinstance(module, Emu3VQVAE):
module.apply(module._init_weights)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@ -1208,6 +1214,8 @@ class Emu3PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Emu3RMSNorm): # noqa: F821
module.weight.data.fill_(1.0)
class Emu3RotaryEmbedding(nn.Module):

View File

@ -747,6 +747,10 @@ class Emu3VQVAE(PreTrainedModel):
def _init_weights(self, module):
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
if module.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(module.bias, -bound, bound)
elif isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
@ -754,8 +758,12 @@ class Emu3VQVAE(PreTrainedModel):
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(module.bias, -bound, bound)
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
nn.init.constant_(module.bias, 0.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_()
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def __init__(self, config: Emu3VQVAEConfig):
super().__init__(config)
@ -894,9 +902,7 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
def _init_weights(self, module):
std = self.config.get_text_config().initializer_range
if isinstance(module, Emu3VQVAE):
module.apply(module._init_weights)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@ -904,6 +910,8 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Emu3RMSNorm): # noqa: F821
module.weight.data.fill_(1.0)
EMU3_TEXT_INPUTS_DOCSTRING = r"""

View File

@ -381,6 +381,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, GemmaRMSNorm):
module.weight.data.fill_(1.0)
GEMMA_INPUTS_DOCSTRING = r"""

View File

@ -426,6 +426,8 @@ class Gemma2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Gemma2RMSNorm):
module.weight.data.fill_(1.0)
GEMMA2_INPUTS_DOCSTRING = r"""

View File

@ -486,13 +486,7 @@ class Gemma3PreTrainedModel(PreTrainedModel):
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of Gemma2 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
@ -502,6 +496,10 @@ class Gemma3PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Gemma3RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, Gemma3MultiModalProjector):
module.mm_input_projection_weight.data.zero_()
GEMMA3_INPUTS_DOCSTRING = r"""

View File

@ -548,13 +548,7 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
]
def _init_weights(self, module):
# important: this ported version of Gemma2 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
@ -564,6 +558,10 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Gemma3RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, Gemma3MultiModalProjector):
module.mm_input_projection_weight.data.zero_()
class Gemma3TextModel(Gemma2Model):

View File

@ -399,6 +399,8 @@ class GlmPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, GlmRMSNorm):
module.weight.data.fill_(1.0)
GLM_INPUTS_DOCSTRING = r"""

View File

@ -407,6 +407,8 @@ class Glm4PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Glm4RMSNorm):
module.weight.data.fill_(1.0)
GLM4_INPUTS_DOCSTRING = r"""

View File

@ -591,26 +591,22 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/GotOcr2/tree/main/got_ocr2 should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, GotOcr2VisionAttention):
if module.use_rel_pos:
module.rel_pos_h.data.zero_()
module.rel_pos_w.data.zero_()
elif isinstance(module, GotOcr2VisionEncoder):
if module.pos_embed is not None:
module.pos_embed.data.zero_()
GOT_OCR2_INPUTS_DOCSTRING = r"""

View File

@ -276,7 +276,23 @@ class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
pass
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, GotOcr2VisionAttention):
if module.use_rel_pos:
module.rel_pos_h.data.zero_()
module.rel_pos_w.data.zero_()
elif isinstance(module, GotOcr2VisionEncoder):
if module.pos_embed is not None:
module.pos_embed.data.zero_()
GOT_OCR2_INPUTS_DOCSTRING = r"""

View File

@ -75,6 +75,9 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, GPTNeoXJapaneseAttention):
if module.dense_bias is not None:
module.dense_bias.data.zero_()
class GPTNeoXJapaneseAttention(nn.Module):

View File

@ -366,6 +366,8 @@ class GranitePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, GraniteRMSNorm):
module.weight.data.fill_(1.0)
class GraniteRotaryEmbedding(nn.Module):

View File

@ -330,11 +330,15 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, GraniteSpeechEncoderProjector):
module.query.data.normal_()
GRANITE_SPEECH_INPUTS_DOCSTRING = r"""

View File

@ -833,8 +833,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
elif isinstance(module, GraniteMoeRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, GraniteMoeParallelExperts):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

View File

@ -745,8 +745,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
elif isinstance(module, GraniteMoeSharedRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, GraniteMoeSharedParallelExperts):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

View File

@ -384,6 +384,8 @@ class HeliumPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, HeliumRMSNorm):
module.weight.data.fill_(1.0)
HELIUM_INPUTS_DOCSTRING = r"""

View File

@ -44,7 +44,7 @@ from ...utils import (
)
from .configuration_idefics import IdeficsConfig
from .perceiver import IdeficsPerceiverResampler
from .vision import IdeficsVisionTransformer
from .vision import IdeficsVisionEmbeddings, IdeficsVisionTransformer
if is_torch_flex_attn_available():
@ -934,7 +934,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
# inference and fine-tuning - so the proper init weights code has been removed - the m4 code
# base should be used for training from scratch and it contains the correct code.
std = self.config.initializer_range
if isinstance(module, nn.Linear):
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@ -942,6 +942,25 @@ class IdeficsPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, IdeficsRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, IdeficsVisionEmbeddings):
module.class_embedding.data.normal_()
elif isinstance(module, IdeficsGatedCrossAttentionLayer):
if self.config.alpha_initializer == "zeros":
module.alpha_cross_attn.data.zero_()
module.alpha_dense.data.zero_()
elif self.config.alpha_initializer == "ones":
module.alpha_cross_attn.data.fill_(1.0)
module.alpha_dense.data.fill_(1.0)
elif self.config.alpha_initializer in {"normal", "gaussian", "random"}:
module.alpha_cross_attn.data.normal_(mean=0.0, std=self.config.alphas_initializer_range)
module.alpha_dense.data.normal_(mean=0.0, std=self.config.alphas_initializer_range)
elif isinstance(module, IdeficsPerceiverResampler):
module.latents.data.normal_()
LLAMA_INPUTS_DOCSTRING = r"""
@ -1495,7 +1514,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config, vision_model=None):

View File

@ -130,6 +130,8 @@ class Idefics2PerceiverConfig(PretrainedConfig):
Number of key-value heads in the perceiver attention block.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing all weight matrices in the model.
"""
model_type = "idefics2_perceiver"
@ -145,6 +147,7 @@ class Idefics2PerceiverConfig(PretrainedConfig):
resampler_head_dim=96,
num_key_value_heads=4,
attention_dropout=0.0,
initializer_range=0.02,
**kwargs,
):
self.hidden_act = hidden_act
@ -156,6 +159,7 @@ class Idefics2PerceiverConfig(PretrainedConfig):
self.num_key_value_heads = num_key_value_heads
self.resampler_head_dim = resampler_head_dim
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
if self.num_key_value_heads > self.resampler_n_heads:
raise ValueError(
f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"

View File

@ -517,14 +517,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
_supports_cache_class = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.get_text_config().initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
@ -534,6 +527,17 @@ class Idefics2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Idefics2RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.MultiheadAttention):
module._reset_parameters() # native torch init
elif isinstance(module, Idefics2MultiheadAttentionPoolingHead):
module.probe.data.normal_()
elif isinstance(module, Idefics2PerceiverResampler):
module.latents.data.fill_(1.0)
IDEFICS2_INPUTS_DOCSTRING = r"""

View File

@ -533,16 +533,8 @@ class Idefics3PreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.get_text_config().initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
@ -552,6 +544,11 @@ class Idefics3PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Idefics3RMSNorm):
module.weight.data.fill_(1.0)
IDEFICS3_VISION_START_DOCSTRING = r"""

View File

@ -323,26 +323,24 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
"InstructBlipQFormerSelfOutput",
]
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_range
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=factor)
if hasattr(module, "bias") and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, InstructBlipVisionEmbeddings):
if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVisionConfig):
factor = self.config.vision_config.initializer_range
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=factor)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, InstructBlipVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, InstructBlipForConditionalGeneration):
module.query_tokens.data.zero_()
INSTRUCTBLIP_START_DOCSTRING = r"""

View File

@ -130,44 +130,6 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
return embeddings
class InstructBlipVideoPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = InstructBlipVideoConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_no_split_modules = [
"InstructBlipVideoQFormerEmbeddings",
"InstructBlipVideoAttention",
"InstructBlipVideoQFormerMultiHeadAttention",
"InstructBlipVideoQFormerSelfOutput",
]
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_range
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=factor)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
if isinstance(module, InstructBlipVideoVisionEmbeddings):
if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVideoVisionConfig):
factor = self.config.vision_config.initializer_range
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class InstructBlipVideoAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -416,73 +378,6 @@ INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
"""
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
main_input_name = "pixel_values"
config_class = InstructBlipVideoVisionConfig
def __init__(self, config: InstructBlipVideoVisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = InstructBlipVideoVisionEmbeddings(config)
self.encoder = InstructBlipVideoEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.post_init()
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVideoVisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()
@ -957,6 +852,194 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module):
return embeddings
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
[`InstructBlipVideoProcessor.__call__`] for details.
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
to serve as text prompt, which the Q-Former model will encode.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
provided to serve as text prompt, which the language model can continue.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
encoder-decoder language model (like T5) is used.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
Only relevant in case an encoder-decoder language model (like T5) is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
class InstructBlipVideoPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = InstructBlipVideoConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_no_split_modules = [
"InstructBlipVideoQFormerEmbeddings",
"InstructBlipVideoAttention",
"InstructBlipVideoQFormerMultiHeadAttention",
"InstructBlipVideoQFormerSelfOutput",
]
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=factor)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=factor)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, InstructBlipVideoVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, InstructBlipVideoForConditionalGeneration):
module.query_tokens.data.zero_()
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
main_input_name = "pixel_values"
config_class = InstructBlipVideoVisionConfig
def __init__(self, config: InstructBlipVideoVisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = InstructBlipVideoVisionEmbeddings(config)
self.encoder = InstructBlipVideoEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.post_init()
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVideoVisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
"""
Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
@ -1186,90 +1269,6 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
)
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
[`InstructBlipVideoProcessor.__call__`] for details.
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
to serve as text prompt, which the Q-Former model will encode.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
provided to serve as text prompt, which the language model can continue.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
encoder-decoder language model (like T5) is used.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
Only relevant in case an encoder-decoder language model (like T5) is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
@add_start_docstrings(
"""
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision

View File

@ -1115,6 +1115,13 @@ class JambaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, JambaRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, JambaMambaMixer):
A = torch.arange(1, module.ssm_state_size + 1)[None, :]
A = A.expand(module.intermediate_size, -1).contiguous()
module.A_log.data.copy_(torch.log(A))
module.D.data.fill_(1.0)
JAMBA_INPUTS_DOCSTRING = r"""

View File

@ -856,8 +856,7 @@ class JetMoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
elif isinstance(module, JetMoeRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, JetMoeParallelExperts):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

View File

@ -389,6 +389,8 @@ class LlamaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, LlamaRMSNorm):
module.weight.data.fill_(1.0)
LLAMA_INPUTS_DOCSTRING = r"""

View File

@ -492,6 +492,17 @@ class Llama4PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Llama4TextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, Llama4TextExperts):
module.gate_up_proj.data.normal_(mean=0.0, std=std)
module.down_proj.data.normal_(mean=0.0, std=std)
elif isinstance(module, Llama4VisionModel):
module.class_embedding.data.normal_(std=module.scale)
module.positional_embedding_vlm.data.normal_(std=module.scale)
LLAMA4_INPUTS_DOCSTRING = r"""

View File

@ -144,23 +144,12 @@ class LlavaPreTrainedModel(PreTrainedModel):
# important: this ported version of Llava isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
LLAVA_INPUTS_DOCSTRING = r"""

View File

@ -236,7 +236,6 @@ LLAVA_NEXT_START_DOCSTRING = r"""
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_NEXT_START_DOCSTRING,
)
# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNext,llava->llava_next
class LlavaNextPreTrainedModel(PreTrainedModel):
config_class = LlavaNextConfig
base_model_prefix = "model"
@ -250,26 +249,15 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, LlavaNextForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
LLAVA_NEXT_INPUTS_DOCSTRING = r"""

View File

@ -129,62 +129,6 @@ class LlavaNextVideoPooler(nn.Module):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlavaNextVideoConfig`] or [`LlavaNextVideoVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
config_class = LlavaNextVideoConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next_video should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class LlavaNextVideoMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextVideoConfig):
super().__init__()
@ -207,6 +151,23 @@ class LlavaNextVideoMultiModalProjector(nn.Module):
return hidden_states
LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlavaNextVideoConfig`] or [`LlavaNextVideoVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@ -394,6 +355,34 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
config_class = LlavaNextVideoConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextVideoForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,

View File

@ -24,6 +24,7 @@ from torch import nn
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
LlavaNextMultiModalProjector,
LlavaNextPreTrainedModel,
image_size_to_num_patches,
)
@ -222,10 +223,23 @@ class LlavaNextVideoPooler(nn.Module):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector):
pass
class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextVideoForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
super().__init__(config, **super_kwargs)

View File

@ -255,28 +255,17 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
_supports_quantized_cache = True
_supports_sdpa = True
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights with LlavaNext->LlavaOnevision
def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, LlavaOnevisionForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""

View File

@ -1412,31 +1412,22 @@ class MimiPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_static_cache = True
# Copied from transformers.models.encodec.modeling_encodec.EncodecPreTrainedModel._init_weights
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LSTM):
for name, param in module.named_parameters():
if "weight" in name:
nn.init.xavier_uniform_(param)
elif "bias" in name:
nn.init.constant_(param, 0.0)
elif isinstance(module, MimiLayerScale):
module.scale.data.fill_(self.config.layer_scale_initial_scale)
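The uniform bias bound used above matches PyTorch's default bias bound of 1/sqrt(fan_in) for a 1-D kernel, recomputed after the Kaiming reset of the weights. A throwaway sketch with an illustrative Conv1d (not the actual Mimi modules):
import math
from torch import nn

conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3, groups=1)
nn.init.kaiming_normal_(conv.weight)
if conv.bias is not None:
    # Same bound as above: k = sqrt(groups / (in_channels * kernel_size[0]))
    k = math.sqrt(conv.groups / (conv.in_channels * conv.kernel_size[0]))
    nn.init.uniform_(conv.bias, a=-k, b=k)
    print(round(k, 4))  # sqrt(1 / 12) ≈ 0.2887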
MIMI_START_DOCSTRING = r"""

View File

@ -318,6 +318,8 @@ class MistralPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MistralRMSNorm):
module.weight.data.fill_(1.0)
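Filling the RMSNorm weight with ones simply restores its constructor default (nn.Parameter(torch.ones(hidden_size))), so the new branch makes re-initialization a true reset for these layers. A standalone sketch with a hypothetical TinyRMSNorm stand-in:
import torch
from torch import nn

class TinyRMSNorm(nn.Module):
    # Hypothetical stand-in mirroring the usual RMSNorm layout: a single scale vector.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

norm = TinyRMSNorm(16)
norm.weight.data.mul_(0.5)   # pretend the scale drifted (e.g. partially loaded weights)
norm.weight.data.fill_(1.0)  # what the new _init_weights branch does
assert torch.all(norm.weight == 1.0)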
class MistralRotaryEmbedding(nn.Module):

View File

@ -203,26 +203,14 @@ class Mistral3PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of Mistral3 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Mistral3RMSNorm):
module.weight.data.fill_(1.0)
MISTRAL3_INPUTS_DOCSTRING = r"""

View File

@ -20,7 +20,7 @@ from torch import nn
from ...activations import ACT2FN
from ...utils import is_torchdynamo_compiling, logging
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
from ..mistral.modeling_mistral import MistralRMSNorm
from .configuration_mistral3 import Mistral3Config
@ -100,6 +100,18 @@ class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class Mistral3PreTrainedModel(LlavaPreTrainedModel):
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, Mistral3RMSNorm):
module.weight.data.fill_(1.0)
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
def get_image_features(
self,

View File

@ -474,6 +474,8 @@ class MixtralPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MixtralRMSNorm):
module.weight.data.fill_(1.0)
MIXTRAL_INPUTS_DOCSTRING = r"""

View File

@ -1029,7 +1029,8 @@ class MllamaPreTrainedModel(PreTrainedModel):
_supports_quantized_cache = True
def _init_weights(self, module):
std = self.config.get_text_config().initializer_range
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
@ -1038,15 +1039,25 @@ class MllamaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.Parameter):
module.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, MllamaTextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, MllamaVisionModel):
nn.init.normal_(module.class_embedding.data, std=std)
elif isinstance(module, MllamaPrecomputedPositionEmbedding):
nn.init.normal_(module.embedding.data, std=std)
nn.init.zeros_(module.gate.data)
elif isinstance(module, MllamaVisionEncoderLayer) and module.is_gated:
nn.init.normal_(module.gate_attn.data, std=std)
nn.init.normal_(module.gate_ffn.data, std=std)
elif isinstance(module, MllamaCrossAttentionDecoderLayer):
module.cross_attn_attn_gate.data.zero_()
module.cross_attn_mlp_gate.data.zero_()
elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding):
if module.is_gated:
module.gate.data.zero_()
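Zeroing the cross-attention decoder gates above makes those residual branches start as identity mappings, assuming the usual hidden + tanh(gate) * branch_output pattern they feed into. A toy illustration of that effect:
import torch

gate = torch.zeros(1)            # what cross_attn_attn_gate.data.zero_() produces
hidden = torch.randn(2, 3)
branch_out = torch.randn(2, 3)   # stand-in for the cross-attention output

out = hidden + torch.tanh(gate) * branch_out
assert torch.equal(out, hidden)  # the freshly-initialized branch contributes nothing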
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(

View File

@ -536,6 +536,10 @@ class MoonshinePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
module.weight.data.fill_(1.0)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:

View File

@ -554,6 +554,10 @@ class MoonshinePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
module.weight.data.fill_(1.0)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:

View File

@ -849,22 +849,19 @@ class MoshiPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv1d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
elif isinstance(module, MoshiFlexibleLinear):
module.weight.data.normal_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MoshiRMSNorm):
module.weight.data.fill_(1.0)
MOSHI_START_DOCSTRING = r"""

View File

@ -623,6 +623,9 @@ class NemotronPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, NemotronLayerNorm1P):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
NEMOTRON_INPUTS_DOCSTRING = r"""

View File

@ -282,6 +282,40 @@ class OlmoDecoderLayer(nn.Module):
return outputs
class OlmoRotaryEmbedding(nn.Module):
def __init__(self, config: OlmoConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
OLMO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -329,40 +363,6 @@ class OlmoPreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
class OlmoRotaryEmbedding(nn.Module):
def __init__(self, config: OlmoConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
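For orientation, the cos/sin pair returned above is consumed by rotating the two halves of each head dimension. A self-contained sketch of that application (illustrative only, not the library's apply_rotary_pos_emb):
import torch

def rotate_half(x):
    # Swap the two halves of the last dimension, flipping the sign of the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, head_dim = 4, 8
q = torch.randn(1, seq_len, head_dim)

# Toy inv_freq / positions standing in for what the module's forward computes.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

q_rotated = q * cos + rotate_half(q) * sin                    # standard RoPE application
print(q_rotated.shape)  # torch.Size([1, 4, 8])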
OLMO_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@ -15,6 +15,7 @@ from ..llama.modeling_llama import (
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
@ -114,10 +115,23 @@ class OlmoDecoderLayer(LlamaDecoderLayer):
self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
class OlmoPreTrainedModel(LlamaPreTrainedModel):
class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
pass
class OlmoPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class OlmoModel(LlamaModel):
def __init__(self, config: OlmoConfig):
super().__init__(config)

View File

@ -286,6 +286,40 @@ class Olmo2DecoderLayer(nn.Module):
return outputs
class Olmo2RotaryEmbedding(nn.Module):
def __init__(self, config: Olmo2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
OLMO2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -331,40 +365,8 @@ class Olmo2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Olmo2RotaryEmbedding(nn.Module):
def __init__(self, config: Olmo2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
elif isinstance(module, Olmo2RMSNorm):
module.weight.data.fill_(1.0)
OLMO2_INPUTS_DOCSTRING = r"""

View File

@ -7,13 +7,14 @@ from ...cache_utils import Cache
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging
from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward
from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
from ..olmo.configuration_olmo import OlmoConfig
from ..olmo.modeling_olmo import (
OlmoAttention,
OlmoDecoderLayer,
OlmoForCausalLM,
OlmoModel,
OlmoRotaryEmbedding,
apply_rotary_pos_emb,
)
@ -287,6 +288,14 @@ class Olmo2DecoderLayer(OlmoDecoderLayer):
return outputs
class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
pass
class Olmo2PreTrainedModel(LlamaPreTrainedModel):
pass
# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
# standard layer norm for the output norm.
class Olmo2Model(OlmoModel):

View File

@ -747,6 +747,8 @@ class OlmoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, OlmoeRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:

View File

@ -511,6 +511,9 @@ class OPTPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
OPT_INPUTS_DOCSTRING = r"""

View File

@ -199,23 +199,12 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
# important: this ported version of PaliGemma isn't meant for training from scratch - only

# inference and fine-tuning
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
PALIGEMMA_INPUTS_DOCSTRING = r"""

View File

@ -412,6 +412,9 @@ class PersimmonPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
PERSIMMON_INPUTS_DOCSTRING = r"""

View File

@ -279,6 +279,40 @@ class PhiDecoderLayer(nn.Module):
return outputs
class PhiRotaryEmbedding(nn.Module):
def __init__(self, config: PhiConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
PHI_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -324,40 +358,9 @@ class PhiPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class PhiRotaryEmbedding(nn.Module):
def __init__(self, config: PhiConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
PHI_INPUTS_DOCSTRING = r"""

View File

@ -20,6 +20,7 @@ from ..llama.modeling_llama import (
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward, # copied from Llama
)
@ -170,10 +171,26 @@ class PhiDecoderLayer(nn.Module):
return outputs
class PhiPreTrainedModel(LlamaPreTrainedModel):
class PhiRotaryEmbedding(LlamaRotaryEmbedding):
pass
class PhiPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
class PhiModel(LlamaModel):
def __init__(self, config: PhiConfig):
super().__init__(config)

View File

@ -373,6 +373,8 @@ class Phi3PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Phi3RMSNorm):
module.weight.data.fill_(1.0)
class Phi3RotaryEmbedding(nn.Module):

View File

@ -1030,6 +1030,9 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv):
module.b1.data.zero_()
module.b2.data.zero_()
def unfold_tensor(tensor, max_seq_len):
@ -1607,6 +1610,40 @@ class Phi4MultimodalFeatureEmbedding(nn.Module):
return inputs_embeds
class Phi4MultimodalRotaryEmbedding(nn.Module):
def __init__(self, config: Phi4MultimodalConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
PHI4_MULTIMODAL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -1653,40 +1690,11 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Phi4MultimodalRotaryEmbedding(nn.Module):
def __init__(self, config: Phi4MultimodalConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
elif isinstance(module, Phi4MultimodalRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, Phi4MultimodalImageEmbedding):
module.global_img_feature_extensor.data.zero_()
module.sub_img_feature_extensor.data.zero_()
PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""

View File

@ -40,7 +40,14 @@ from ...utils import (
replace_return_docstrings,
)
from ..phi3.configuration_phi3 import Phi3Config
from ..phi3.modeling_phi3 import Phi3DecoderLayer, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
from ..phi3.modeling_phi3 import (
Phi3DecoderLayer,
Phi3ForCausalLM,
Phi3Model,
Phi3PreTrainedModel,
Phi3RMSNorm,
Phi3RotaryEmbedding,
)
from ..siglip.configuration_siglip import SiglipVisionConfig
from ..siglip.modeling_siglip import (
SiglipEncoder,
@ -1133,6 +1140,9 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv):
module.b1.data.zero_()
module.b2.data.zero_()
class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
@ -1519,6 +1529,28 @@ PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
"""
class Phi4MultimodalRotaryEmbedding(Phi3RotaryEmbedding):
pass
class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Phi4MultimodalRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, Phi4MultimodalImageEmbedding):
module.global_img_feature_extensor.data.zero_()
module.sub_img_feature_extensor.data.zero_()
class Phi4MultimodalModel(Phi3Model, nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalMMDecoderLayer`]
@ -1829,7 +1861,7 @@ __all__ = [
"Phi4MultimodalAudioModel",
"Phi4MultimodalVisionPreTrainedModel",
"Phi4MultimodalVisionModel",
"Phi4MultimodalPreTrainedModel", # noqa
"Phi4MultimodalPreTrainedModel",
"Phi4MultimodalModel",
"Phi4MultimodalForCausalLM",
"Phi4MultimodalVisionConfig",

View File

@ -923,6 +923,9 @@ class PhimoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
PHIMOE_INPUTS_DOCSTRING = r"""

View File

@ -383,20 +383,13 @@ class PixtralPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PixtralAttentionLayer"]
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.initializer_range
)
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, PixtralRMSNorm):
module.weight.data.fill_(1.0)
PIXTRAL_INPUTS_DOCSTRING = r"""

View File

@ -254,15 +254,10 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class PromptDepthAnythingReassembleLayer(nn.Module):

View File

@ -210,15 +210,10 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class PromptDepthAnythingReassembleLayer(nn.Module):

View File

@ -331,6 +331,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)
class Qwen2RotaryEmbedding(nn.Module):

View File

@ -92,6 +92,7 @@ class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@ -108,6 +109,7 @@ class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range
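With initializer_range now stored on the vision encoder config itself, its sub-model init no longer has to reach for a parent config. A hedged usage sketch (import path assumed to follow the usual transformers layout):
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniVisionEncoderConfig,
)

vision_config = Qwen2_5OmniVisionEncoderConfig(initializer_range=0.02)
print(vision_config.initializer_range)  # 0.02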
class Qwen2_5OmniAudioEncoderConfig(PretrainedConfig):

View File

@ -75,6 +75,26 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
Qwen2_5Omni_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -112,7 +132,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
# inference and fine-tuning - so the proper init weights code has been removed
std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)):
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@ -120,6 +140,11 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)
class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
@ -1102,26 +1127,6 @@ class Qwen2_5OmniMLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5OmniVisionAttention,
"flash_attention_2": Qwen2_5OmniVisionFlashAttention2,

View File

@ -36,6 +36,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLModel,
Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLVisionBlock,
Qwen2RMSNorm,
)
from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer
@ -130,6 +131,7 @@ class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(
@ -145,6 +147,7 @@ class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig):
window_size,
out_hidden_size,
fullatt_block_indexes,
initializer_range=initializer_range,
**kwargs,
)
del self.tokens_per_second
@ -1027,7 +1030,7 @@ class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
# inference and fine-tuning - so the proper init weights code has been removed
std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)):
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@ -1035,6 +1038,11 @@ class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)
class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):

View File

@ -46,6 +46,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@ -63,6 +64,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range
class Qwen2_5_VLConfig(PretrainedConfig):

View File

@ -388,6 +388,8 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):

View File

@ -89,6 +89,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@ -106,6 +107,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range
class Qwen2_5_VLConfig(Qwen2VLConfig):
@ -224,7 +226,18 @@ class Qwen2_5_VLVisionBlock(nn.Module):
class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
pass
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):

View File

@ -779,6 +779,8 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen2MoeRMSNorm):
module.weight.data.fill_(1.0)
QWEN2MOE_INPUTS_DOCSTRING = r"""

View File

@ -38,6 +38,7 @@ class Qwen2VLVisionConfig(PretrainedConfig):
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@ -52,6 +53,7 @@ class Qwen2VLVisionConfig(PretrainedConfig):
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.initializer_range = initializer_range
class Qwen2VLConfig(PretrainedConfig):

View File

@ -914,6 +914,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
@ -922,6 +923,11 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):

View File

@ -358,6 +358,8 @@ class Qwen3PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen3RMSNorm):
module.weight.data.fill_(1.0)
class Qwen3RotaryEmbedding(nn.Module):

View File

@ -488,6 +488,8 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen3MoeRMSNorm):
module.weight.data.fill_(1.0)
QWEN3_MOE_INPUTS_DOCSTRING = r"""

View File

@ -581,6 +581,13 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if getattr(module, "bias", None) is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, RecurrentGemmaRMSNorm):
module.weight.data.fill_(1.0)
def _setup_cache(self, config, batch, device, dtype):
layers = getattr(self, "model", self).layers

View File

@ -1040,8 +1040,6 @@ class RTDetrPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initalize the weights"""
"""initialize linear layer bias value according to a given probability value."""
if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)):
if module.class_embed is not None:
for layer in module.class_embed:
@ -1055,7 +1053,7 @@ class RTDetrPreTrainedModel(PreTrainedModel):
nn.init.constant_(layer.layers[-1].weight, 0)
nn.init.constant_(layer.layers[-1].bias, 0)
if isinstance(module, RTDetrMultiscaleDeformableAttention):
elif isinstance(module, RTDetrMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
@ -1078,17 +1076,21 @@ class RTDetrPreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
if isinstance(module, RTDetrModel):
elif isinstance(module, RTDetrModel):
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
bias = float(-math.log((1 - prior_prob) / prior_prob))
nn.init.xavier_uniform_(module.enc_score_head.weight)
nn.init.constant_(module.enc_score_head.bias, bias)
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
nn.init.xavier_uniform_(module.weight_embedding.weight)
if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:

View File

@ -1314,8 +1314,6 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initalize the weights"""
"""initialize linear layer bias value according to a given probability value."""
if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)):
if module.class_embed is not None:
for layer in module.class_embed:
@ -1329,7 +1327,7 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
nn.init.constant_(layer.layers[-1].weight, 0)
nn.init.constant_(layer.layers[-1].bias, 0)
if isinstance(module, RTDetrV2MultiscaleDeformableAttention):
elif isinstance(module, RTDetrV2MultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
@ -1352,17 +1350,21 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
if isinstance(module, RTDetrV2Model):
elif isinstance(module, RTDetrV2Model):
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
bias = float(-math.log((1 - prior_prob) / prior_prob))
nn.init.xavier_uniform_(module.enc_score_head.weight)
nn.init.constant_(module.enc_score_head.bias, bias)
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
nn.init.xavier_uniform_(module.weight_embedding.weight)
if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:

View File

@ -80,14 +80,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.get_text_config().initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
@ -97,6 +90,9 @@ class SmolVLMPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
class SmolVLMVisionEmbeddings(nn.Module):

View File

@ -94,7 +94,20 @@ class SmolVLMVisionConfig(Idefics3VisionConfig):
class SmolVLMPreTrainedModel(Idefics3PreTrainedModel):
pass
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
class SmolVLMVisionTransformer(Idefics3VisionTransformer):

View File

@ -666,6 +666,9 @@ class StableLmPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
STABLELM_INPUTS_DOCSTRING = r"""

View File

@ -275,6 +275,40 @@ class Starcoder2DecoderLayer(nn.Module):
return outputs
class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, config: Starcoder2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
STARCODER2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -320,40 +354,9 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, config: Starcoder2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
STARCODER2_INPUTS_DOCSTRING = r"""


@@ -41,6 +41,8 @@ from ..mistral.modeling_mistral import (
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
MistralPreTrainedModel,
MistralRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
@@ -143,6 +145,26 @@ class Starcoder2DecoderLayer(MistralDecoderLayer):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
class Starcoder2RotaryEmbedding(MistralRotaryEmbedding):
pass
class Starcoder2PreTrainedModel(MistralPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined


@@ -166,15 +166,6 @@ class UperNetHead(nn.Module):
padding=1,
)
def init_weights(self):
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
def psp_forward(self, inputs):
x = inputs[-1]
psp_outs = [x]
@@ -266,15 +257,6 @@ class UperNetFCNHead(nn.Module):
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
def init_weights(self):
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
# just take the relevant feature maps
hidden_states = encoder_hidden_states[self.in_index]
@@ -296,18 +278,13 @@ class UperNetPreTrainedModel(PreTrainedModel):
_no_split_modules = []
def _init_weights(self, module):
if isinstance(module, UperNetPreTrainedModel):
module.backbone.init_weights()
module.decode_head.init_weights()
if module.auxiliary_head is not None:
module.auxiliary_head.init_weights()
def init_weights(self):
"""Initialize the weights"""
self.backbone.init_weights()
self.decode_head.init_weights()
if self.auxiliary_head is not None:
self.auxiliary_head.init_weights()
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
UPERNET_START_DOCSTRING = r"""


@@ -128,7 +128,6 @@ VIPLLAVA_START_DOCSTRING = r"""
"The bare VipLlava Model outputting raw hidden-states without any specific head on top.",
VIPLLAVA_START_DOCSTRING,
)
# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->VipLlava,llava->vipllava
class VipLlavaPreTrainedModel(PreTrainedModel):
config_class = VipLlavaConfig
base_model_prefix = "model"
@@ -142,26 +141,15 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/vipllava should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
VIPLLAVA_INPUTS_DOCSTRING = r"""


@@ -786,10 +786,14 @@ class WhisperPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, WhisperEncoder):
with torch.no_grad():
embed_positions = module.embed_positions.weight
embed_positions.copy_(sinusoids(*embed_positions.shape))
module.embed_positions.weight.copy_(sinusoids(*module.embed_positions.weight.shape))
elif isinstance(module, WhisperForAudioClassification):
if self.config.use_weighted_layer_sum:
module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1))
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""


@@ -850,10 +850,9 @@ class ZambaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, ZambaRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, ZambaMambaMixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
module.x_proj_weight.data.normal_(mean=0.0, std=std)
dt_init_std = self.config.mamba_dt_rank**-0.5
nn.init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std)
@@ -866,10 +865,12 @@ class ZambaPreTrainedModel(PreTrainedModel):
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
module.dt_proj_bias.data.copy_(inv_dt)
with torch.no_grad():
module.dt_proj_bias.copy_(inv_dt)
module.dt_proj_bias._no_reinit = True
A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(module.intermediate_size, -1).contiguous()
module.A_log.data.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1))
module.D.data.fill_(1.0)
@classmethod
@classmethod
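As a side note on the dt_proj_bias initialization above, the inverse-of-softplus identity can be checked numerically with a standalone sketch (not part of the diff):

    import torch
    import torch.nn.functional as F

    # If y = softplus(x) = log(1 + exp(x)), then x = y + log(-expm1(-y)).
    dt = torch.rand(8).clamp(min=1e-3)          # stand-in positive time-step values
    inv_dt = dt + torch.log(-torch.expm1(-dt))  # same inverse-of-softplus expression as above
    assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)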


@@ -1225,10 +1225,9 @@ class Zamba2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (Zamba2RMSNorm, Zamba2RMSNormGated)):
module.weight.data.fill_(1.0)
elif isinstance(module, Zamba2MambaMixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
dt = torch.exp(
torch.rand(self.config.n_mamba_heads)
* (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
@@ -1236,10 +1235,11 @@ class Zamba2PreTrainedModel(PreTrainedModel):
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
module.dt_bias.data.copy_(inv_dt)
with torch.no_grad():
module.dt_bias.copy_(inv_dt)
module.dt_bias._no_reinit = True
A = torch.arange(1, module.num_heads + 1)
module.A_log.data.copy_(torch.log(A))
module.D.data.fill_(1.0)
ZAMBA2_START_DOCSTRING = r"""

Some files were not shown because too many files have changed in this diff.