Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
Detect and fix most _init_weights() issues - make it work for composite models (#37070)
* Update test_modeling_common.py
* Fix Llama and its modular children
* Update test_modeling_common.py
* qwen3
* first try at prioritizing models
* Update test_modeling_common.py
* Update test_modeling_common.py
* Update test_modeling_common.py
* test
* fix
* fix
* more models
* more
* more
* more
* smarter init for composite models!
* fix post rebase
* smol
* fix missing args
* more
* typo
* Super elegant and efficient init for submodels
* Update modeling_utils.py
* style
* last fixes
* cleanup
* finalize cleanup
* CIs
* improve docstring
* Update modeling_utils.py
* llama4
* style
* CIs
* style
* add dpt
* granite speech
* qwen 2.5 omni
* better fix
* Parse the config file instead
* CIs
This commit is contained in:
parent 1897a02d83
commit 4e53840920
@@ -2449,6 +2449,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         self._init_weights(module)
         module._is_hf_initialized = True

+    @torch.no_grad()
+    def initialize_weights(self):
+        """
+        This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
+        This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
+        module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
+        model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
+        is extremely error prone and inefficient.
+
+        Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
+        `torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
+        `module.weight.data.zero_()`.
+        """
+        if not hasattr(torch.nn.Module, "smart_apply"):
+            # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function
+            # to apply as we go down the graph
+            def smart_apply(self, fn):
+                for module in self.children():
+                    # We found a sub-model: recursively dispatch its own init function now!
+                    if hasattr(module, "_init_weights"):
+                        module.smart_apply(module._initialize_weights)
+                    else:
+                        module.smart_apply(fn)
+                fn(self)
+                return self
+
+            torch.nn.Module.smart_apply = smart_apply
+
+        # Let the magic happen with this simple call
+        self.smart_apply(self._initialize_weights)
+
     def tie_weights(self):
         """
         Tie the weights between the input embeddings and the output embeddings.
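To make the dispatch above concrete, here is a minimal, runnable sketch of the same idea with toy classes (all names are made up; this is not the transformers implementation): walk the module graph and switch to a child's own `_init_weights` whenever the child defines one, so a sub-model's parameters are initialized by its own rules rather than by the outer model's.

```python
from torch import nn

class ToyVision(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.fill_(2.0)   # vision-specific rule

class ToyComposite(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = ToyVision()
        self.head = nn.Linear(4, 4)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.zero_()      # outer (text) rule

def smart_apply(module, fn):
    # Same idea as the patch: when a child defines its own `_init_weights`,
    # recurse with that function instead of the one inherited from the parent.
    for child in module.children():
        child_fn = child._init_weights if hasattr(child, "_init_weights") else fn
        smart_apply(child, child_fn)
    fn(module)
    return module

model = ToyComposite()
smart_apply(model, model._init_weights)
print(model.head.weight.sum().item())         # 0.0  -> initialized by the outer rule
print(model.vision.proj.weight.sum().item())  # 32.0 -> initialized by the vision rule
```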
@@ -3074,7 +3105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi

         if _init_weights:
             # Initialize weights
-            self.apply(self._initialize_weights)
+            self.initialize_weights()

             # Tie weights should be skipped when not initializing all weights
             # since from_pretrained(...) calls tie weights anyways
@@ -5286,9 +5317,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                     )
                 )
                 with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
-                    self.apply(self._initialize_weights)
+                    self.initialize_weights()
             else:
-                self.apply(self._initialize_weights)
+                self.initialize_weights()

     def get_parameter_or_buffer(self, target: str):
         """
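Regarding the `torch.no_grad()` remark in the new docstring: in-place edits of a parameter that requires grad raise a runtime error outside a no-grad context unless they go through `.data` (or `torch.nn.init`, which disables grad internally). A small illustration in plain PyTorch, independent of this repository:

```python
import torch
from torch import nn

lin = nn.Linear(4, 4)

# lin.weight.zero_()          # would raise: "a leaf Variable that requires grad
#                             # is being used in an in-place operation"

with torch.no_grad():          # the context that @torch.no_grad() provides
    lin.weight.zero_()         # fine inside no_grad
lin.weight.data.fill_(1.0)     # .data bypasses autograd, so this works either way
print(lin.weight.sum().item())  # 16.0
```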
@ -679,12 +679,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, AriaTextRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
ARIA_TEXT_START_DOCSTRING = r"""
|
||||
@ -724,14 +722,17 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# This uses torch's original init
|
||||
module._reset_parameters()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, AriaProjector):
|
||||
nn.init.trunc_normal_(module.query, std=std)
|
||||
|
||||
|
@ -1255,12 +1255,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, AriaTextRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
@ -1269,14 +1267,17 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# This uses torch's original init
|
||||
module._reset_parameters()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, AriaProjector):
|
||||
nn.init.trunc_normal_(module.query, std=std)
|
||||
|
||||
|
@ -127,26 +127,19 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of AyaVision isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/AyaVision/tree/main/aya_vision should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -113,6 +113,21 @@ class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
|
||||
_supports_quantized_cache = False
|
||||
_supports_static_cache = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
pass
|
||||
|
@ -1052,10 +1052,16 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, BambaMixer):
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
|
||||
|
||||
BAMBA_INPUTS_DOCSTRING = r"""
|
||||
|
@ -820,10 +820,16 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, BambaMixer):
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
|
||||
|
||||
BAMBA_INPUTS_DOCSTRING = r"""
|
||||
|
@ -423,22 +423,30 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, Blip2VisionEmbeddings):
|
||||
if hasattr(self.config, "vision_config") and not isinstance(self.config, Blip2VisionConfig):
|
||||
factor = self.config.vision_config.initializer_range
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, Blip2VisionEmbeddings):
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
elif isinstance(
|
||||
module,
|
||||
(
|
||||
Blip2Model,
|
||||
Blip2TextModelWithProjection,
|
||||
Blip2VisionModelWithProjection,
|
||||
Blip2ForConditionalGeneration,
|
||||
Blip2ForImageTextRetrieval,
|
||||
),
|
||||
):
|
||||
module.query_tokens.data.zero_()
|
||||
|
||||
|
||||
BLIP_2_START_DOCSTRING = r"""
|
||||
|
@ -1056,12 +1056,16 @@ class ChameleonPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, ChameleonVQVAE):
|
||||
module.apply(module._init_weights)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, ChameleonRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
@ -1096,18 +1100,6 @@ class ChameleonVQVAE(ChameleonPreTrainedModel):
|
||||
config_class = ChameleonVQVAEConfig
|
||||
_no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif isinstance(module, nn.GroupNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def __init__(self, config: ChameleonVQVAEConfig):
|
||||
super().__init__(config)
|
||||
|
||||
|
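The two Chameleon hunks above show the recurring cleanup in this commit: the outer `ChameleonPreTrainedModel._init_weights` no longer special-cases `ChameleonVQVAE` with a manual `module.apply(module._init_weights)` (a second recursion that the new `initialize_weights()` makes unnecessary); here the sub-model's rules (e.g. the `nn.GroupNorm` branch) are folded into the outer method and the standalone `ChameleonVQVAE._init_weights` is deleted, while other composite models in this diff (e.g. `Emu3PreTrainedModel`) keep the sub-model's `_init_weights` and simply drop the manual recursion, relying on the new dispatch instead.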
@@ -416,6 +416,8 @@ class CoherePreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, CohereLayerNorm):
+            module.weight.data.fill_(1.0)


 COHERE_INPUTS_DOCSTRING = r"""
@ -41,6 +41,7 @@ from ..llama.modeling_llama import (
|
||||
LlamaForCausalLM,
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaRotaryEmbedding,
|
||||
eager_attention_forward,
|
||||
)
|
||||
@ -277,6 +278,21 @@ class CohereDecoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class CoherePreTrainedModel(LlamaPreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, CohereLayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class CohereModel(LlamaModel):
|
||||
def __init__(self, config: CohereConfig):
|
||||
super().__init__(config)
|
||||
|
@ -424,6 +424,8 @@ class Cohere2PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Cohere2LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
COHERE2_INPUTS_DOCSTRING = r"""
|
||||
|
@ -557,10 +557,10 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DeepseekV3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, DeepseekV3TopkRouter):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif isinstance(module, nn.Parameter):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
|
||||
|
||||
DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
|
||||
|
@ -347,10 +347,10 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DeepseekV3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, DeepseekV3TopkRouter):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif isinstance(module, nn.Parameter):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
|
||||
|
||||
class DeepseekV3Model(LlamaModel):
|
||||
|
@ -625,6 +625,13 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, DiffLlamaAttention):
|
||||
module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
|
||||
module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
|
||||
module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
|
||||
module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
|
||||
|
||||
|
||||
class DiffLlamaRotaryEmbedding(nn.Module):
|
||||
|
@ -431,6 +431,24 @@ class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
|
||||
_supports_flex_attn = False
|
||||
_supports_attention_backend = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, DiffLlamaAttention):
|
||||
module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
|
||||
module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
|
||||
module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
|
||||
module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
|
||||
|
||||
|
||||
class DiffLlamaModel(LlamaModel):
|
||||
pass
|
||||
|
@@ -852,7 +852,7 @@ class DPTPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
         if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)):
@ -1020,6 +1020,10 @@ class Emu3VQVAE(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
if module.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(module.bias, -bound, bound)
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
||||
if module.bias is not None:
|
||||
@ -1027,8 +1031,12 @@ class Emu3VQVAE(PreTrainedModel):
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
nn.init.uniform_(module.bias, -bound, bound)
|
||||
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
|
||||
nn.init.constant_(module.weight, 1)
|
||||
nn.init.constant_(module.bias, 0)
|
||||
nn.init.constant_(module.weight, 1.0)
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_()
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def __init__(self, config: Emu3VQVAEConfig):
|
||||
super().__init__(config)
|
||||
@ -1198,9 +1206,7 @@ class Emu3PreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
if isinstance(module, Emu3VQVAE):
|
||||
module.apply(module._init_weights)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
@ -1208,6 +1214,8 @@ class Emu3PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Emu3RMSNorm): # noqa: F821
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class Emu3RotaryEmbedding(nn.Module):
|
||||
|
@ -747,6 +747,10 @@ class Emu3VQVAE(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
if module.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(module.bias, -bound, bound)
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
||||
if module.bias is not None:
|
||||
@ -754,8 +758,12 @@ class Emu3VQVAE(PreTrainedModel):
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
nn.init.uniform_(module.bias, -bound, bound)
|
||||
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
|
||||
nn.init.constant_(module.weight, 1)
|
||||
nn.init.constant_(module.bias, 0)
|
||||
nn.init.constant_(module.weight, 1.0)
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_()
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def __init__(self, config: Emu3VQVAEConfig):
|
||||
super().__init__(config)
|
||||
@ -894,9 +902,7 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
if isinstance(module, Emu3VQVAE):
|
||||
module.apply(module._init_weights)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
@ -904,6 +910,8 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Emu3RMSNorm): # noqa: F821
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
EMU3_TEXT_INPUTS_DOCSTRING = r"""
|
||||
|
@ -381,6 +381,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GemmaRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
GEMMA_INPUTS_DOCSTRING = r"""
|
||||
|
@ -426,6 +426,8 @@ class Gemma2PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Gemma2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
GEMMA2_INPUTS_DOCSTRING = r"""
|
||||
|
@ -486,13 +486,7 @@ class Gemma3PreTrainedModel(PreTrainedModel):
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
std = self.config.initializer_range
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
@ -502,6 +496,10 @@ class Gemma3PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Gemma3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Gemma3MultiModalProjector):
|
||||
module.mm_input_projection_weight.data.zero_()
|
||||
|
||||
|
||||
GEMMA3_INPUTS_DOCSTRING = r"""
|
||||
|
@ -548,13 +548,7 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
|
||||
]
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
std = self.config.initializer_range
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
@ -564,6 +558,10 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Gemma3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Gemma3MultiModalProjector):
|
||||
module.mm_input_projection_weight.data.zero_()
|
||||
|
||||
|
||||
class Gemma3TextModel(Gemma2Model):
|
||||
|
@ -399,6 +399,8 @@ class GlmPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GlmRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
GLM_INPUTS_DOCSTRING = r"""
|
||||
|
@ -407,6 +407,8 @@ class Glm4PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Glm4RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
GLM4_INPUTS_DOCSTRING = r"""
|
||||
|
@ -591,26 +591,22 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/GotOcr2/tree/main/got_ocr2 should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, GotOcr2VisionAttention):
|
||||
if module.use_rel_pos:
|
||||
module.rel_pos_h.data.zero_()
|
||||
module.rel_pos_w.data.zero_()
|
||||
elif isinstance(module, GotOcr2VisionEncoder):
|
||||
if module.pos_embed is not None:
|
||||
module.pos_embed.data.zero_()
|
||||
|
||||
|
||||
GOT_OCR2_INPUTS_DOCSTRING = r"""
|
||||
|
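The GotOcr2 hunk above (and the Idefics2/Idefics3 hunks further down) replaces the old `hasattr`-based branching with a one-liner: `std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)`. A minimal sketch of how that fallback behaves, with made-up config classes:

```python
class ToyTextConfig:
    initializer_range = 0.02

class ToyConfig:
    # deliberately no top-level `initializer_range`
    def get_text_config(self):
        return ToyTextConfig()

config = ToyConfig()
# note: getattr evaluates the fallback expression eagerly, so get_text_config() runs either way
std = getattr(config, "initializer_range", config.get_text_config().initializer_range)
print(std)  # 0.02 -> taken from the text sub-config
```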
@ -276,7 +276,23 @@ class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
|
||||
|
||||
class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
|
||||
pass
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, GotOcr2VisionAttention):
|
||||
if module.use_rel_pos:
|
||||
module.rel_pos_h.data.zero_()
|
||||
module.rel_pos_w.data.zero_()
|
||||
elif isinstance(module, GotOcr2VisionEncoder):
|
||||
if module.pos_embed is not None:
|
||||
module.pos_embed.data.zero_()
|
||||
|
||||
|
||||
GOT_OCR2_INPUTS_DOCSTRING = r"""
|
||||
|
@ -75,6 +75,9 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, GPTNeoXJapaneseAttention):
|
||||
if module.dense_bias is not None:
|
||||
module.dense_bias.data.zero_()
|
||||
|
||||
|
||||
class GPTNeoXJapaneseAttention(nn.Module):
|
||||
|
@ -366,6 +366,8 @@ class GranitePreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GraniteRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class GraniteRotaryEmbedding(nn.Module):
|
||||
|
@ -330,11 +330,15 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, GraniteSpeechEncoderProjector):
|
||||
module.query.data.normal_()
|
||||
|
||||
|
||||
GRANITE_SPEECH_INPUTS_DOCSTRING = r"""
|
||||
|
@ -833,8 +833,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, GraniteMoeRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, GraniteMoeParallelExperts):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
@ -745,8 +745,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, GraniteMoeSharedRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, GraniteMoeSharedParallelExperts):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
@ -384,6 +384,8 @@ class HeliumPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, HeliumRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
HELIUM_INPUTS_DOCSTRING = r"""
|
||||
|
@ -44,7 +44,7 @@ from ...utils import (
|
||||
)
|
||||
from .configuration_idefics import IdeficsConfig
|
||||
from .perceiver import IdeficsPerceiverResampler
|
||||
from .vision import IdeficsVisionTransformer
|
||||
from .vision import IdeficsVisionEmbeddings, IdeficsVisionTransformer
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
@ -934,7 +934,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the m4 code
|
||||
# base should be used for training from scratch and it contains the correct code.
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
@ -942,6 +942,25 @@ class IdeficsPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, IdeficsRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, IdeficsVisionEmbeddings):
|
||||
module.class_embedding.data.normal_()
|
||||
elif isinstance(module, IdeficsGatedCrossAttentionLayer):
|
||||
if self.config.alpha_initializer == "zeros":
|
||||
module.alpha_cross_attn.data.zero_()
|
||||
module.alpha_dense.data.zero_()
|
||||
elif self.config.alpha_initializer == "ones":
|
||||
module.alpha_cross_attn.data.fill_(1.0)
|
||||
module.alpha_dense.data.fill_(1.0)
|
||||
elif self.config.alpha_initializer in {"normal", "gaussian", "random"}:
|
||||
module.alpha_cross_attn.data.normal_(mean=0.0, std=self.config.alphas_initializer_range)
|
||||
module.alpha_dense.data.normal_(mean=0.0, std=self.config.alphas_initializer_range)
|
||||
elif isinstance(module, IdeficsPerceiverResampler):
|
||||
module.latents.data.normal_()
|
||||
|
||||
|
||||
LLAMA_INPUTS_DOCSTRING = r"""
|
||||
@ -1495,7 +1514,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
|
||||
|
||||
class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
|
||||
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
||||
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
|
||||
|
||||
def __init__(self, config, vision_model=None):
|
||||
|
@ -130,6 +130,8 @@ class Idefics2PerceiverConfig(PretrainedConfig):
|
||||
Number of key-value heads in the perceiver attention block.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation for initializing all weight matrices in the model.
|
||||
"""
|
||||
|
||||
model_type = "idefics2_perceiver"
|
||||
@ -145,6 +147,7 @@ class Idefics2PerceiverConfig(PretrainedConfig):
|
||||
resampler_head_dim=96,
|
||||
num_key_value_heads=4,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
self.hidden_act = hidden_act
|
||||
@ -156,6 +159,7 @@ class Idefics2PerceiverConfig(PretrainedConfig):
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.resampler_head_dim = resampler_head_dim
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
if self.num_key_value_heads > self.resampler_n_heads:
|
||||
raise ValueError(
|
||||
f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
|
||||
|
@ -517,14 +517,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.get_text_config().initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
@ -534,6 +527,17 @@ class Idefics2PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, Idefics2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
module._reset_parameters() # native torch init
|
||||
elif isinstance(module, Idefics2MultiheadAttentionPoolingHead):
|
||||
module.probe.data.normal_()
|
||||
elif isinstance(module, Idefics2PerceiverResampler):
|
||||
module.latents.data.fill_(1.0)
|
||||
|
||||
|
||||
IDEFICS2_INPUTS_DOCSTRING = r"""
|
||||
|
@ -533,16 +533,8 @@ class Idefics3PreTrainedModel(PreTrainedModel):
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
|
||||
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.get_text_config().initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
@ -552,6 +544,11 @@ class Idefics3PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, Idefics3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
IDEFICS3_VISION_START_DOCSTRING = r"""
|
||||
|
@ -323,26 +323,24 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
|
||||
"InstructBlipQFormerSelfOutput",
|
||||
]
|
||||
|
||||
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, InstructBlipVisionEmbeddings):
|
||||
if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVisionConfig):
|
||||
factor = self.config.vision_config.initializer_range
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, InstructBlipVisionEmbeddings):
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
elif isinstance(module, InstructBlipForConditionalGeneration):
|
||||
module.query_tokens.data.zero_()
|
||||
|
||||
|
||||
INSTRUCTBLIP_START_DOCSTRING = r"""
|
||||
|
@ -130,44 +130,6 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class InstructBlipVideoPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = InstructBlipVideoConfig
|
||||
base_model_prefix = "blip"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
_no_split_modules = [
|
||||
"InstructBlipVideoQFormerEmbeddings",
|
||||
"InstructBlipVideoAttention",
|
||||
"InstructBlipVideoQFormerMultiHeadAttention",
|
||||
"InstructBlipVideoQFormerSelfOutput",
|
||||
]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, InstructBlipVideoVisionEmbeddings):
|
||||
if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVideoVisionConfig):
|
||||
factor = self.config.vision_config.initializer_range
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class InstructBlipVideoAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -416,73 +378,6 @@ INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
config_class = InstructBlipVideoVisionConfig
|
||||
|
||||
def __init__(self, config: InstructBlipVideoVisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = InstructBlipVideoVisionEmbeddings(config)
|
||||
self.encoder = InstructBlipVideoEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVideoVisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
|
||||
def __init__(self, config, is_cross_attention=False):
|
||||
super().__init__()
|
||||
@ -957,6 +852,194 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
|
||||
[`InstructBlipVideoProcessor.__call__`] for details.
|
||||
|
||||
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
||||
to serve as text prompt, which the Q-Former model will encode.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
||||
provided to serve as text prompt, which the language model can continue.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
||||
encoder-decoder language model (like T5) is used.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
|
||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
|
||||
Only relevant in case an encoder-decoder language model (like T5) is used.
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
|
||||
class InstructBlipVideoPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = InstructBlipVideoConfig
|
||||
base_model_prefix = "blip"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
_no_split_modules = [
|
||||
"InstructBlipVideoQFormerEmbeddings",
|
||||
"InstructBlipVideoAttention",
|
||||
"InstructBlipVideoQFormerMultiHeadAttention",
|
||||
"InstructBlipVideoQFormerSelfOutput",
|
||||
]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, InstructBlipVideoVisionEmbeddings):
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
elif isinstance(module, InstructBlipVideoForConditionalGeneration):
|
||||
module.query_tokens.data.zero_()
|
||||
|
||||
|
||||
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
config_class = InstructBlipVideoVisionConfig
|
||||
|
||||
def __init__(self, config: InstructBlipVideoVisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = InstructBlipVideoVisionEmbeddings(config)
|
||||
self.encoder = InstructBlipVideoEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVideoVisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
|
||||
"""
|
||||
Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
|
||||
@ -1186,90 +1269,6 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
|
||||
)
|
||||
|
||||
|
||||
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
|
||||
[`InstructBlipVideoProcessor.__call__`] for details.
|
||||
|
||||
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
||||
to serve as text prompt, which the Q-Former model will encode.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
||||
provided to serve as text prompt, which the language model can continue.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
||||
encoder-decoder language model (like T5) is used.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
|
||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
|
||||
Only relevant in case an encoder-decoder language model (like T5) is used.
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
|
||||
|
@ -1115,6 +1115,13 @@ class JambaPreTrainedModel(PreTrainedModel):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, JambaRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, JambaMambaMixer):
            A = torch.arange(1, module.ssm_state_size + 1)[None, :]
            A = A.expand(module.intermediate_size, -1).contiguous()
            module.A_log.data.copy_(torch.log(A))
            module.D.data.fill_(1.0)

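For reference, a minimal standalone sketch (not part of the commit; toy sizes, and the dtype is made explicit here) of the S4-style A_log/D initialization that the JambaMambaMixer branch above performs:

import torch

ssm_state_size, intermediate_size = 4, 3
A = torch.arange(1, ssm_state_size + 1, dtype=torch.float32)[None, :]  # shape (1, 4): [[1., 2., 3., 4.]]
A = A.expand(intermediate_size, -1).contiguous()                       # shape (3, 4), each row [1., 2., 3., 4.]
A_log = torch.log(A)                                                   # stored parameter; A is recovered as exp(A_log)
D = torch.ones(intermediate_size)                                      # skip-connection scale initialized to 1
print(A_log.shape, D)
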
JAMBA_INPUTS_DOCSTRING = r"""

@ -856,8 +856,7 @@ class JetMoePreTrainedModel(PreTrainedModel):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
        elif isinstance(module, JetMoeRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, JetMoeParallelExperts):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

@ -389,6 +389,8 @@ class LlamaPreTrainedModel(PreTrainedModel):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)

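For readers following the per-model hunks, a self-contained sketch (toy classes, not the real Llama modules) of the isinstance-dispatch pattern that every _init_weights in this commit follows: each branch initializes one kind of submodule in place, under no_grad.

import torch
from torch import nn

class ToyRMSNorm(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(hidden_size))

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + 1e-6)

@torch.no_grad()
def init_weights(module, std=0.02):
    # In-place init, dispatched on the module type (same shape as the hunks above).
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, ToyRMSNorm):
        module.weight.data.fill_(1.0)  # norm scales start as the identity

layers = nn.ModuleList([nn.Embedding(10, 8, padding_idx=0), nn.Linear(8, 8), ToyRMSNorm(8)])
layers.apply(init_weights)
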
LLAMA_INPUTS_DOCSTRING = r"""

@ -492,6 +492,17 @@ class Llama4PreTrainedModel(PreTrainedModel):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, Llama4TextRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, Llama4TextExperts):
            module.gate_up_proj.data.normal_(mean=0.0, std=std)
            module.down_proj.data.normal_(mean=0.0, std=std)
        elif isinstance(module, Llama4VisionModel):
            module.class_embedding.data.normal_(std=module.scale)
            module.positional_embedding_vlm.data.normal_(std=module.scale)


LLAMA4_INPUTS_DOCSTRING = r"""

@ -144,23 +144,12 @@ class LlavaPreTrainedModel(PreTrainedModel):
        # important: this ported version of Llava isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )
        std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

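A small sketch with a stand-in config (not the real LlavaConfig) showing that the new getattr(...) one-liner behaves like the hasattr(...) ternary it replaces: prefer the top-level initializer_range when it exists, otherwise fall back to the text sub-config returned by get_text_config().

from types import SimpleNamespace

class ToyComposite:
    # Hypothetical composite config: a text sub-config plus an optional top-level initializer_range.
    def __init__(self, initializer_range=None):
        self.text_config = SimpleNamespace(initializer_range=0.02)
        if initializer_range is not None:
            self.initializer_range = initializer_range

    def get_text_config(self):
        return self.text_config

for cfg in (ToyComposite(initializer_range=0.01), ToyComposite()):
    std = getattr(cfg, "initializer_range", cfg.get_text_config().initializer_range)
    print(std)  # 0.01 for the first config, 0.02 (text sub-config fallback) for the second
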
LLAVA_INPUTS_DOCSTRING = r"""
|
||||
|
@ -236,7 +236,6 @@ LLAVA_NEXT_START_DOCSTRING = r"""
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
LLAVA_NEXT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNext,llava->llava_next
|
||||
class LlavaNextPreTrainedModel(PreTrainedModel):
|
||||
config_class = LlavaNextConfig
|
||||
base_model_prefix = "model"
|
||||
@ -250,26 +249,15 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of LlavaNext isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, LlavaNextForConditionalGeneration):
|
||||
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
|
||||
module.image_newline.data.normal_(mean=0.0, std=embed_std)
|
||||
|
||||
|
||||
LLAVA_NEXT_INPUTS_DOCSTRING = r"""
|
||||
|
@ -129,62 +129,6 @@ class LlavaNextVideoPooler(nn.Module):
|
||||
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`LlavaNextVideoConfig`] or [`LlavaNextVideoVisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
LLAVA_NEXT_VIDEO_START_DOCSTRING,
|
||||
)
|
||||
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
|
||||
config_class = LlavaNextVideoConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlavaNextVideoVisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next_video should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class LlavaNextVideoMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaNextVideoConfig):
|
||||
super().__init__()
|
||||
@ -207,6 +151,23 @@ class LlavaNextVideoMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`LlavaNextVideoConfig`] or [`LlavaNextVideoVisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
@ -394,6 +355,34 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
LLAVA_NEXT_VIDEO_START_DOCSTRING,
|
||||
)
|
||||
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
|
||||
config_class = LlavaNextVideoConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlavaNextVideoVisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, LlavaNextVideoForConditionalGeneration):
|
||||
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
|
||||
module.image_newline.data.normal_(mean=0.0, std=embed_std)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
|
||||
LLAVA_NEXT_VIDEO_START_DOCSTRING,
|
||||
|
@ -24,6 +24,7 @@ from torch import nn
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
LlavaNextCausalLMOutputWithPast,
|
||||
LlavaNextForConditionalGeneration,
|
||||
LlavaNextMultiModalProjector,
|
||||
LlavaNextPreTrainedModel,
|
||||
image_size_to_num_patches,
|
||||
)
|
||||
@ -222,10 +223,23 @@ class LlavaNextVideoPooler(nn.Module):
|
||||
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
|
||||
class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector):
|
||||
pass
|
||||
|
||||
|
||||
class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, LlavaNextVideoForConditionalGeneration):
|
||||
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
|
||||
module.image_newline.data.normal_(mean=0.0, std=embed_std)
|
||||
|
||||
|
||||
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
|
||||
super().__init__(config, **super_kwargs)
|
||||
|
@ -255,28 +255,17 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights with LlavaNext->LlavaOnevision
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of LlavaNext isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, LlavaOnevisionForConditionalGeneration):
|
||||
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
|
||||
module.image_newline.data.normal_(mean=0.0, std=embed_std)
|
||||
|
||||
|
||||
LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
|
||||
|
@ -1412,31 +1412,22 @@ class MimiPreTrainedModel(PreTrainedModel):
    _supports_cache_class = True
    _supports_static_cache = True

    # Copied from transformers.models.encodec.modeling_encodec.EncodecPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
        elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if "weight" in name:
                    nn.init.xavier_uniform_(param)
                elif "bias" in name:
                    nn.init.constant_(param, 0.0)
        elif isinstance(module, MimiLayerScale):
            module.scale.data.fill_(self.config.layer_scale_initial_scale)

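A toy-sized sketch (not from the commit) of the conv bias initialization used above; the bound k = sqrt(groups / (in_channels * kernel_size[0])) is the same fan-in style bound PyTorch's own conv reset uses.

import math
import torch
from torch import nn

conv = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=7, groups=1, bias=True)

nn.init.kaiming_normal_(conv.weight)
k = math.sqrt(conv.groups / (conv.in_channels * conv.kernel_size[0]))  # sqrt(1 / (32 * 7))
nn.init.uniform_(conv.bias, a=-k, b=k)

print(round(k, 4), conv.bias.abs().max().item() <= k)  # bound is roughly 0.0668, bias stays inside it
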
MIMI_START_DOCSTRING = r"""
|
||||
|
@ -318,6 +318,8 @@ class MistralPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, MistralRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class MistralRotaryEmbedding(nn.Module):
|
||||
|
@ -203,26 +203,14 @@ class Mistral3PreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Mistral3 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Mistral3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
MISTRAL3_INPUTS_DOCSTRING = r"""
|
||||
|
@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...utils import is_torchdynamo_compiling, logging
|
||||
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration
|
||||
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
|
||||
from ..mistral.modeling_mistral import MistralRMSNorm
|
||||
from .configuration_mistral3 import Mistral3Config
|
||||
|
||||
@ -100,6 +100,18 @@ class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class Mistral3PreTrainedModel(LlavaPreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, Mistral3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
def get_image_features(
|
||||
self,
|
||||
|
@ -474,6 +474,8 @@ class MixtralPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, MixtralRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
MIXTRAL_INPUTS_DOCSTRING = r"""
|
||||
|
@ -1029,7 +1029,8 @@ class MllamaPreTrainedModel(PreTrainedModel):
    _supports_quantized_cache = True

    def _init_weights(self, module):
        std = self.config.get_text_config().initializer_range
        std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:

@ -1038,15 +1039,25 @@ class MllamaPreTrainedModel(PreTrainedModel):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Parameter):
            module.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, MllamaTextRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, MllamaVisionModel):
            nn.init.normal_(module.class_embedding.data, std=std)
        elif isinstance(module, MllamaPrecomputedPositionEmbedding):
            nn.init.normal_(module.embedding.data, std=std)
            nn.init.zeros_(module.gate.data)
        elif isinstance(module, MllamaVisionEncoderLayer) and module.is_gated:
            nn.init.normal_(module.gate_attn.data, std=std)
            nn.init.normal_(module.gate_ffn.data, std=std)
        elif isinstance(module, MllamaCrossAttentionDecoderLayer):
            module.cross_attn_attn_gate.data.zero_()
            module.cross_attn_mlp_gate.data.zero_()
        elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding):
            if module.is_gated:
                module.gate.data.zero_()

# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
|
@ -536,6 +536,10 @@ class MoonshinePreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
|
||||
module.weight.data.fill_(1.0)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
|
@ -554,6 +554,10 @@ class MoonshinePreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
|
||||
module.weight.data.fill_(1.0)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
|
@ -849,22 +849,19 @@ class MoshiPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
||||
nn.init.uniform_(module.bias, a=-k, b=k)
|
||||
elif isinstance(module, MoshiFlexibleLinear):
|
||||
module.weight.data.normal_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, MoshiRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
MOSHI_START_DOCSTRING = r"""
|
||||
|
@ -623,6 +623,9 @@ class NemotronPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, NemotronLayerNorm1P):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
NEMOTRON_INPUTS_DOCSTRING = r"""
|
||||
|
@ -282,6 +282,40 @@ class OlmoDecoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class OlmoRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: OlmoConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
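A standalone sketch (toy sizes; the base-10000 inverse-frequency formula is assumed here, the real values come from the config's rope_init_fn) of the cos/sin computation in the rotary embedding forward above: an outer product of inverse frequencies and positions, duplicated along the last dim, then cos/sin scaled by attention_scaling.

import torch

batch, seq_len, head_dim = 2, 5, 8
attention_scaling = 1.0

inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
position_ids = torch.arange(seq_len)[None, :].expand(batch, -1)                # (batch, seq_len)

inv_freq_expanded = inv_freq[None, :, None].float().expand(batch, -1, 1)       # (batch, head_dim/2, 1)
position_ids_expanded = position_ids[:, None, :].float()                       # (batch, 1, seq_len)

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)            # (batch, seq_len, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                                        # (batch, seq_len, head_dim)
cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling

print(cos.shape, sin.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])
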
OLMO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
@ -329,40 +363,6 @@ class OlmoPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class OlmoRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: OlmoConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
OLMO_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
|
@ -15,6 +15,7 @@ from ..llama.modeling_llama import (
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
eager_attention_forward,
|
||||
)
|
||||
@ -114,10 +115,23 @@ class OlmoDecoderLayer(LlamaDecoderLayer):
|
||||
self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
|
||||
class OlmoPreTrainedModel(LlamaPreTrainedModel):
|
||||
class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class OlmoPreTrainedModel(LlamaPreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class OlmoModel(LlamaModel):
|
||||
def __init__(self, config: OlmoConfig):
|
||||
super().__init__(config)
|
||||
|
@ -286,6 +286,40 @@ class Olmo2DecoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class Olmo2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Olmo2Config, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
OLMO2_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
@ -331,40 +365,8 @@ class Olmo2PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Olmo2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Olmo2Config, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
elif isinstance(module, Olmo2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
OLMO2_INPUTS_DOCSTRING = r"""
|
||||
|
@ -7,13 +7,14 @@ from ...cache_utils import Cache
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import logging
|
||||
from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward
|
||||
from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
|
||||
from ..olmo.configuration_olmo import OlmoConfig
|
||||
from ..olmo.modeling_olmo import (
|
||||
OlmoAttention,
|
||||
OlmoDecoderLayer,
|
||||
OlmoForCausalLM,
|
||||
OlmoModel,
|
||||
OlmoRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
@ -287,6 +288,14 @@ class Olmo2DecoderLayer(OlmoDecoderLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class Olmo2PreTrainedModel(LlamaPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
|
||||
# standard layer norm for the output norm.
|
||||
class Olmo2Model(OlmoModel):
|
||||
|
@ -747,6 +747,8 @@ class OlmoePreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, OlmoeRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
|
@ -511,6 +511,9 @@ class OPTPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
OPT_INPUTS_DOCSTRING = r"""
|
||||
|
@ -199,23 +199,12 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
    def _init_weights(self, module):
        # important: this ported version of PaliGemma isn't meant for training from scratch - only
        # inference and fine-tuning
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )
        std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

PALIGEMMA_INPUTS_DOCSTRING = r"""
|
||||
|
@ -412,6 +412,9 @@ class PersimmonPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
PERSIMMON_INPUTS_DOCSTRING = r"""
|
||||
|
@ -279,6 +279,40 @@ class PhiDecoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class PhiRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: PhiConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
PHI_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
@ -324,40 +358,9 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class PhiRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: PhiConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
PHI_INPUTS_DOCSTRING = r"""
|
||||
|
@ -20,6 +20,7 @@ from ..llama.modeling_llama import (
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
eager_attention_forward, # copied from Llama
|
||||
)
|
||||
@ -170,10 +171,26 @@ class PhiDecoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class PhiPreTrainedModel(LlamaPreTrainedModel):
|
||||
class PhiRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class PhiPreTrainedModel(LlamaPreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class PhiModel(LlamaModel):
|
||||
def __init__(self, config: PhiConfig):
|
||||
super().__init__(config)
|
||||
|
@ -373,6 +373,8 @@ class Phi3PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Phi3RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class Phi3RotaryEmbedding(nn.Module):
|
||||
|
@ -1030,6 +1030,9 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv):
|
||||
module.b1.data.zero_()
|
||||
module.b2.data.zero_()
|
||||
|
||||
|
||||
def unfold_tensor(tensor, max_seq_len):
|
||||
@ -1607,6 +1610,40 @@ class Phi4MultimodalFeatureEmbedding(nn.Module):
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
class Phi4MultimodalRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Phi4MultimodalConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
PHI4_MULTIMODAL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
@ -1653,40 +1690,11 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Phi4MultimodalRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Phi4MultimodalConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
elif isinstance(module, Phi4MultimodalRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Phi4MultimodalImageEmbedding):
|
||||
module.global_img_feature_extensor.data.zero_()
|
||||
module.sub_img_feature_extensor.data.zero_()
|
||||
|
||||
|
||||
PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
|
||||
|
@ -40,7 +40,14 @@ from ...utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ..phi3.configuration_phi3 import Phi3Config
|
||||
from ..phi3.modeling_phi3 import Phi3DecoderLayer, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
|
||||
from ..phi3.modeling_phi3 import (
|
||||
Phi3DecoderLayer,
|
||||
Phi3ForCausalLM,
|
||||
Phi3Model,
|
||||
Phi3PreTrainedModel,
|
||||
Phi3RMSNorm,
|
||||
Phi3RotaryEmbedding,
|
||||
)
|
||||
from ..siglip.configuration_siglip import SiglipVisionConfig
|
||||
from ..siglip.modeling_siglip import (
|
||||
SiglipEncoder,
|
||||
@ -1133,6 +1140,9 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv):
|
||||
module.b1.data.zero_()
|
||||
module.b2.data.zero_()
|
||||
|
||||
|
||||
class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
|
||||
@ -1519,6 +1529,28 @@ PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
class Phi4MultimodalRotaryEmbedding(Phi3RotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Phi4MultimodalRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Phi4MultimodalImageEmbedding):
|
||||
module.global_img_feature_extensor.data.zero_()
|
||||
module.sub_img_feature_extensor.data.zero_()
|
||||
|
||||
|
||||
class Phi4MultimodalModel(Phi3Model, nn.Module):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalMMDecoderLayer`]
|
||||
@ -1829,7 +1861,7 @@ __all__ = [
|
||||
"Phi4MultimodalAudioModel",
|
||||
"Phi4MultimodalVisionPreTrainedModel",
|
||||
"Phi4MultimodalVisionModel",
|
||||
"Phi4MultimodalPreTrainedModel", # noqa
|
||||
"Phi4MultimodalPreTrainedModel",
|
||||
"Phi4MultimodalModel",
|
||||
"Phi4MultimodalForCausalLM",
|
||||
"Phi4MultimodalVisionConfig",
|
||||
|
@ -923,6 +923,9 @@ class PhimoePreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
PHIMOE_INPUTS_DOCSTRING = r"""
|
||||
|
@ -383,20 +383,13 @@ class PixtralPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["PixtralAttentionLayer"]

    def _init_weights(self, module):
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.initializer_range
        )

        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, PixtralRMSNorm):
            module.weight.data.fill_(1.0)

PIXTRAL_INPUTS_DOCSTRING = r"""
|
||||
|
@ -254,15 +254,10 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class PromptDepthAnythingReassembleLayer(nn.Module):
|
||||
|
@ -210,15 +210,10 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class PromptDepthAnythingReassembleLayer(nn.Module):
|
||||
|
@ -331,6 +331,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Qwen2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
|
@ -92,6 +92,7 @@ class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig):
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -108,6 +109,7 @@ class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig):
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioEncoderConfig(PretrainedConfig):
|
||||
|
@ -75,6 +75,26 @@ if is_flash_attn_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Qwen2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
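As a quick sanity check of the `Qwen2RMSNorm` computation added above (an illustration only, not part of the commit; the hidden size and inputs are assumed): dividing by the root mean square over the last dimension leaves every hidden vector with unit RMS before the learned `weight` rescales it.

# Illustration of the RMSNorm formula above; shapes and values are assumptions.
import torch

eps = 1e-6
weight = torch.ones(8)                        # learnable scale, initialized to ones
x = torch.randn(2, 4, 8)                      # (batch, seq_len, hidden)
variance = x.float().pow(2).mean(-1, keepdim=True)
y = weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)
print(y.pow(2).mean(-1).sqrt())               # ~1.0 everywhere: unit RMS per hidden vector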
Qwen2_5Omni_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads

@ -112,7 +132,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
# inference and fine-tuning - so the proper init weights code has been removed
std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02

if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)):
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@ -120,6 +140,11 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)


class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
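For context (an illustration, not part of the diff): `_init_weights` is a single type dispatch that is applied to every submodule, so widening the isinstance tuple with `nn.ConvTranspose1d` and adding the `nn.LayerNorm` branch, as the hunks above do, is all that is needed for those module types to stop falling through uninitialized. A minimal, self-contained sketch of the pattern; the toy model and the `std` value are assumptions:

# Hypothetical toy model; names are illustrative, not from the commit.
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.deconv = nn.ConvTranspose1d(4, 4, kernel_size=3)
        self.norm = nn.LayerNorm(8)

def init_fn(module, std=0.02):
    # Same dispatch shape as the hunks above: one isinstance branch per module family.
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.weight.data.fill_(1.0)
        module.bias.data.zero_()

ToyModel().apply(init_fn)  # visits every submodule once; unlisted types are left untouched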
@ -1102,26 +1127,6 @@ class Qwen2_5OmniMLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5OmniVisionAttention,
"flash_attention_2": Qwen2_5OmniVisionFlashAttention2,

@ -36,6 +36,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLModel,
Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLVisionBlock,
Qwen2RMSNorm,
)
from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer
@ -130,6 +131,7 @@ class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(
@ -145,6 +147,7 @@ class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig):
window_size,
out_hidden_size,
fullatt_block_indexes,
initializer_range=initializer_range,
**kwargs,
)
del self.tokens_per_second
@ -1027,7 +1030,7 @@ class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
# inference and fine-tuning - so the proper init weights code has been removed
std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02

if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)):
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@ -1035,6 +1038,11 @@ class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)


class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
@ -46,6 +46,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@ -63,6 +64,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range


class Qwen2_5_VLConfig(PretrainedConfig):

@ -388,6 +388,8 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)


class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):

@ -89,6 +89,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@ -106,6 +107,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range


class Qwen2_5_VLConfig(Qwen2VLConfig):
@ -224,7 +226,18 @@ class Qwen2_5_VLVisionBlock(nn.Module):


class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
pass
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)


class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
@ -779,6 +779,8 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen2MoeRMSNorm):
module.weight.data.fill_(1.0)


QWEN2MOE_INPUTS_DOCSTRING = r"""

@ -38,6 +38,7 @@ class Qwen2VLVisionConfig(PretrainedConfig):
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@ -52,6 +53,7 @@ class Qwen2VLVisionConfig(PretrainedConfig):
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.initializer_range = initializer_range


class Qwen2VLConfig(PretrainedConfig):

@ -914,6 +914,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
std = self.config.initializer_range

if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
@ -922,6 +923,11 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Qwen2RMSNorm):
module.weight.data.fill_(1.0)


class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):

@ -358,6 +358,8 @@ class Qwen3PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen3RMSNorm):
module.weight.data.fill_(1.0)


class Qwen3RotaryEmbedding(nn.Module):

@ -488,6 +488,8 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Qwen3MoeRMSNorm):
module.weight.data.fill_(1.0)


QWEN3_MOE_INPUTS_DOCSTRING = r"""

@ -581,6 +581,13 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if getattr(module, "bias", None) is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

elif isinstance(module, RecurrentGemmaRMSNorm):
module.weight.data.fill_(1.0)

def _setup_cache(self, config, batch, device, dtype):
layers = getattr(self, "model", self).layers
@ -1040,8 +1040,6 @@ class RTDetrPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
"""Initalize the weights"""

"""initialize linear layer bias value according to a given probability value."""
if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)):
if module.class_embed is not None:
for layer in module.class_embed:
@ -1055,7 +1053,7 @@ class RTDetrPreTrainedModel(PreTrainedModel):
nn.init.constant_(layer.layers[-1].weight, 0)
nn.init.constant_(layer.layers[-1].bias, 0)

if isinstance(module, RTDetrMultiscaleDeformableAttention):
elif isinstance(module, RTDetrMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
@ -1078,17 +1076,21 @@ class RTDetrPreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)

if isinstance(module, RTDetrModel):
elif isinstance(module, RTDetrModel):
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
bias = float(-math.log((1 - prior_prob) / prior_prob))
nn.init.xavier_uniform_(module.enc_score_head.weight)
nn.init.constant_(module.enc_score_head.bias, bias)

if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()

elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()

if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
nn.init.xavier_uniform_(module.weight_embedding.weight)
if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:

@ -1314,8 +1314,6 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
"""Initalize the weights"""

"""initialize linear layer bias value according to a given probability value."""
if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)):
if module.class_embed is not None:
for layer in module.class_embed:
@ -1329,7 +1327,7 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
nn.init.constant_(layer.layers[-1].weight, 0)
nn.init.constant_(layer.layers[-1].bias, 0)

if isinstance(module, RTDetrV2MultiscaleDeformableAttention):
elif isinstance(module, RTDetrV2MultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
@ -1352,17 +1350,21 @@ class RTDetrV2PreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)

if isinstance(module, RTDetrV2Model):
elif isinstance(module, RTDetrV2Model):
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
bias = float(-math.log((1 - prior_prob) / prior_prob))
nn.init.xavier_uniform_(module.enc_score_head.weight)
nn.init.constant_(module.enc_score_head.bias, bias)

if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()

elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()

if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
nn.init.xavier_uniform_(module.weight_embedding.weight)
if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
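The only change in the two `_init_weights` bodies above (RTDetr and RTDetrV2) is turning the later stand-alone `if isinstance(...)` checks into `elif`, so at most one branch runs per module and a module that already matched a specific branch can no longer be re-initialized by a more generic one further down. A small, purely illustrative sketch of the difference; `SpecialLinear` is a made-up class, not from the commit:

# Illustration only: why chaining the checks with elif matters.
import torch.nn as nn

class SpecialLinear(nn.Linear):   # hypothetical module that is *also* an nn.Linear
    pass

m = SpecialLinear(4, 4)

hits = []
if isinstance(m, SpecialLinear):
    hits.append("special init")
if isinstance(m, nn.Linear):      # independent `if`: runs as well and overwrites the special init
    hits.append("generic init")
print(hits)                       # ['special init', 'generic init']

hits = []
if isinstance(m, SpecialLinear):
    hits.append("special init")
elif isinstance(m, nn.Linear):    # `elif`: skipped once the specific branch matched
    hits.append("generic init")
print(hits)                       # ['special init']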
@ -80,14 +80,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel):
_supports_cache_class = True

def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.get_text_config().initializer_range
)

if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
@ -97,6 +90,9 @@ class SmolVLMPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


class SmolVLMVisionEmbeddings(nn.Module):

@ -94,7 +94,20 @@ class SmolVLMVisionConfig(Idefics3VisionConfig):


class SmolVLMPreTrainedModel(Idefics3PreTrainedModel):
pass
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


class SmolVLMVisionTransformer(Idefics3VisionTransformer):
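The `std` computation in both SmolVLM hunks above is a pure simplification: `getattr(self.config, "initializer_range", fallback)` returns the attribute when it exists and the fallback otherwise, which is what the removed conditional expression did (the only nuance being that the fallback expression is now always evaluated). A tiny illustration with made-up config objects:

# Illustration only; SimpleNamespace stands in for the real config classes.
from types import SimpleNamespace

text_config = SimpleNamespace(initializer_range=0.02)
config_with = SimpleNamespace(initializer_range=0.05)
config_without = SimpleNamespace()  # no initializer_range attribute

print(getattr(config_with, "initializer_range", text_config.initializer_range))     # 0.05
print(getattr(config_without, "initializer_range", text_config.initializer_range))  # 0.02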
@ -666,6 +666,9 @@ class StableLmPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


STABLELM_INPUTS_DOCSTRING = r"""

@ -275,6 +275,40 @@ class Starcoder2DecoderLayer(nn.Module):
return outputs


class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, config: Starcoder2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


STARCODER2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
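As a standalone illustration of what the `Starcoder2RotaryEmbedding.forward` added in the hunk above computes (the head dimension, sequence length, and the 10000 base are assumptions for the example, not values taken from the commit): the inverse frequencies are outer-multiplied with the positions, duplicated along the last axis, and turned into the cos/sin tables that rotary attention consumes.

# Illustration only: the cos/sin construction mirrored outside the class.
import torch

head_dim, seq_len = 16, 6
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))   # (head_dim/2,)
position_ids = torch.arange(seq_len).float()[None, :]                            # (1, seq_len)
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)     # (1, seq_len, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                                          # (1, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([1, 6, 16]) for both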
@ -320,40 +354,9 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()


class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, config: Starcoder2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


STARCODER2_INPUTS_DOCSTRING = r"""

@ -41,6 +41,8 @@ from ..mistral.modeling_mistral import (
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
MistralPreTrainedModel,
MistralRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
@ -143,6 +145,26 @@ class Starcoder2DecoderLayer(MistralDecoderLayer):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)


class Starcoder2RotaryEmbedding(MistralRotaryEmbedding):
pass


class Starcoder2PreTrainedModel(MistralPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined
@ -166,15 +166,6 @@ class UperNetHead(nn.Module):
padding=1,
)

def init_weights(self):
self.apply(self._init_weights)

def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()

def psp_forward(self, inputs):
x = inputs[-1]
psp_outs = [x]
@ -266,15 +257,6 @@ class UperNetFCNHead(nn.Module):

self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

def init_weights(self):
self.apply(self._init_weights)

def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()

def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
# just take the relevant feature maps
hidden_states = encoder_hidden_states[self.in_index]
@ -296,18 +278,13 @@ class UperNetPreTrainedModel(PreTrainedModel):
_no_split_modules = []

def _init_weights(self, module):
if isinstance(module, UperNetPreTrainedModel):
module.backbone.init_weights()
module.decode_head.init_weights()
if module.auxiliary_head is not None:
module.auxiliary_head.init_weights()

def init_weights(self):
"""Initialize the weights"""
self.backbone.init_weights()
self.decode_head.init_weights()
if self.auxiliary_head is not None:
self.auxiliary_head.init_weights()
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


UPERNET_START_DOCSTRING = r"""

@ -128,7 +128,6 @@ VIPLLAVA_START_DOCSTRING = r"""
"The bare VipLlava Model outputting raw hidden-states without any specific head on top.",
VIPLLAVA_START_DOCSTRING,
)
# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->VipLlava,llava->vipllava
class VipLlavaPreTrainedModel(PreTrainedModel):
config_class = VipLlavaConfig
base_model_prefix = "model"
@ -142,26 +141,15 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True

def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/vipllava should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)

if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)


VIPLLAVA_INPUTS_DOCSTRING = r"""
@ -786,10 +786,14 @@ class WhisperPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, WhisperEncoder):
with torch.no_grad():
embed_positions = module.embed_positions.weight
embed_positions.copy_(sinusoids(*embed_positions.shape))
module.embed_positions.weight.copy_(sinusoids(*module.embed_positions.weight.shape))
elif isinstance(module, WhisperForAudioClassification):
if self.config.use_weighted_layer_sum:
module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1))

def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""

@ -850,10 +850,9 @@ class ZambaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, ZambaRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, ZambaMambaMixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True

module.x_proj_weight.data.normal_(mean=0.0, std=std)
dt_init_std = self.config.mamba_dt_rank**-0.5
nn.init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std)
@ -866,10 +865,12 @@ class ZambaPreTrainedModel(PreTrainedModel):
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
module.dt_proj_bias.data.copy_(inv_dt)

with torch.no_grad():
module.dt_proj_bias.copy_(inv_dt)
module.dt_proj_bias._no_reinit = True
A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(module.intermediate_size, -1).contiguous()
module.A_log.data.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1))
module.D.data.fill_(1.0)

@classmethod
@classmethod
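The Zamba hunk above (and the matching Zamba2 hunk below) keeps the same `dt` re-parameterization while moving the copy under `torch.no_grad()`: `inv_dt = dt + log(-expm1(-dt))` equals `log(expm1(dt))`, the inverse of softplus, so `softplus(dt_proj_bias)` recovers the sampled time step. A quick numerical check; the sample values are illustrative, not from the commit:

# Illustration only: verifying the inverse-softplus identity used above.
import torch
import torch.nn.functional as F

dt = torch.rand(8) * 0.1 + 1e-3               # assumed sample range for the time steps
inv_dt = dt + torch.log(-torch.expm1(-dt))    # == log(expm1(dt)), inverse of softplus
print(torch.allclose(F.softplus(inv_dt), dt, atol=1e-6))  # True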
@ -1225,10 +1225,9 @@ class Zamba2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (Zamba2RMSNorm, Zamba2RMSNormGated)):
module.weight.data.fill_(1.0)
elif isinstance(module, Zamba2MambaMixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True

dt = torch.exp(
torch.rand(self.config.n_mamba_heads)
* (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
@ -1236,10 +1235,11 @@ class Zamba2PreTrainedModel(PreTrainedModel):
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
module.dt_bias.data.copy_(inv_dt)

with torch.no_grad():
module.dt_bias.copy_(inv_dt)
module.dt_bias._no_reinit = True
A = torch.arange(1, module.num_heads + 1)
module.A_log.data.copy_(torch.log(A))
module.D.data.fill_(1.0)


ZAMBA2_START_DOCSTRING = r"""
Some files were not shown because too many files have changed in this diff.