Fix initialization of OneFormer (#38901)

* fix initialization of OneFormer

* remove redundant initializations

* remove redundant initializations

* remove redundant initializations

* keep BC (backward compatibility)
BUI Van Tuan 2025-06-27 12:39:37 +02:00 committed by GitHub
parent 540a10848c
commit 371c471113
2 changed files with 57 additions and 52 deletions

src/transformers/models/oneformer/modeling_oneformer.py

@@ -2773,7 +2773,6 @@ class OneFormerPreTrainedModel(PreTrainedModel):
elif isinstance(module, OneFormerTransformerDecoder):
nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
nn.init.constant_(module.query_input_projection.bias, 0)
module.query_input_projection._is_hf_initialized = True
elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
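# evenly spaced angles, one per attention head; in the standard deformable-attention scheme these seed the initial sampling-offset directions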
@@ -2793,24 +2792,9 @@ class OneFormerPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.value_proj.bias.data, 0.0)
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
elif isinstance(module, OneFormerPixelDecoderEncoderOnly):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
elif isinstance(module, OneFormerPixelDecoder):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
nn.init.normal_(module.level_embed, std=0)
elif isinstance(module, OneFormerTransformerDecoderSelfAttentionLayer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, OneFormerTransformerDecoderCrossAttentionLayer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, OneFormerTransformerDecoderFFNLayer):
elif isinstance(module, OneFormerTransformerDecoderLayer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
@@ -2818,21 +2802,6 @@ class OneFormerPreTrainedModel(PreTrainedModel):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, OneFormerPixelLevelModule):
for submodule in module.modules():
if isinstance(submodule, (nn.Conv2d, nn.Linear)):
submodule.weight.data.normal_(mean=0.0, std=std)
if submodule.bias is not None:
submodule.bias.data.zero_()
elif isinstance(module, OneFormerTextContextDecoder):
for submodule in module.modules():
if isinstance(submodule, nn.Linear):
nn.init.trunc_normal_(submodule.weight, std=0.02)
if isinstance(submodule, nn.Linear) and submodule.bias is not None:
nn.init.constant_(submodule.bias, 0)
elif isinstance(submodule, nn.LayerNorm):
nn.init.constant_(submodule.bias, 0)
nn.init.constant_(submodule.weight, 1.0)
elif isinstance(module, OneFormerTextTransformer):
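# width- and depth-scaled stds for the text transformer: wider and deeper stacks get proportionally smaller initial weights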
proj_std = (module.width**-0.5) * ((2 * module.num_layers) ** -0.5)
attn_std = module.width**-0.5
@@ -2848,16 +2817,11 @@ class OneFormerPreTrainedModel(PreTrainedModel):
if hasattr(module, "reference_points"):
nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
nn.init.constant_(module.reference_points.bias.data, 0.0)
elif isinstance(module, OneFormerTaskModel):
elif isinstance(module, OneFormerMLPPredictionHead):
for submodule in module.modules():
if isinstance(module, OneFormerMLPPredictionHead):
for submodule in module.modules():
if isinstance(submodule, nn.Linear):
nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
nn.init.constant_(submodule.bias, 0)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(submodule, nn.Linear):
nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
nn.init.constant_(submodule.bias, 0)
elif isinstance(module, nn.MultiheadAttention):
module.in_proj_weight.data.normal_(mean=0.0, std=std)
module.in_proj_bias.data.zero_()
@@ -2865,10 +2829,15 @@ class OneFormerPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, OneFormerLoss):
module.logit_scale.data.fill_(np.log(1 / self.config.contrastive_temperature))
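# seeds the contrastive logit scale at log(1 / temperature); a temperature of 1 would give exactly 0, which the updated test below relies on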
@auto_docstring
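For context (not part of this commit's diff): the test file below switches to the shared _config_zero_init helper, which shrinks every *_std / *_range-style config attribute to roughly zero so that correctly initialized parameters end up with a mean of 0 (or 1 for norm weights). A rough, self-contained sketch of that pattern, using hypothetical ToyConfig / ToyModel / config_zero_init names rather than the OneFormer classes:

import copy
import torch.nn as nn

class ToyConfig:
    def __init__(self, initializer_std=0.02):
        self.initializer_std = initializer_std

class ToyModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        # mirror a typical _init_weights: normal(std) for linear weights, 1/0 for norms
        self.linear.weight.data.normal_(mean=0.0, std=config.initializer_std)
        self.linear.bias.data.zero_()
        self.norm.weight.data.fill_(1.0)
        self.norm.bias.data.zero_()

def config_zero_init(config):
    # same idea as _config_zero_init: push every *_std / *_range attribute to ~0
    cfg = copy.deepcopy(config)
    for key in cfg.__dict__:
        if "_std" in key or "_range" in key:
            setattr(cfg, key, 1e-10)
    return cfg

model = ToyModel(config_zero_init(ToyConfig()))
for name, param in model.named_parameters():
    mean = ((param.data.mean() * 1e9).round() / 1e9).item()
    assert mean in [0.0, 1.0], f"{name} seems not properly initialized"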

tests/models/oneformer/test_modeling_oneformer.py

@@ -13,14 +13,13 @@
# limitations under the License.
"""Testing suite for the PyTorch OneFormer model."""
import copy
import inspect
import unittest
import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import OneFormerConfig, is_torch_available, is_vision_available
from transformers import AutoModelForImageClassification, OneFormerConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
is_flaky,
require_timm,
@@ -35,7 +34,7 @@ from transformers.testing_utils import (
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin
from ...test_modeling_common import ModelTesterMixin, _config_zero_init
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -51,14 +50,6 @@ if is_vision_available():
from PIL import Image
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
setattr(configs_no_init, key, 1e-10)
return configs_no_init
class OneFormerModelTester:
def __init__(
self,
@@ -375,6 +366,7 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.is_training = True
config.contrastive_temperature = 1
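# contrastive_temperature = 1 makes OneFormerLoss fill logit_scale with log(1 / 1) = 0, so it passes the mean-in-{0.0, 1.0} check below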
configs_no_init = _config_zero_init(config)
@@ -382,12 +374,56 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
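# the parameters skipped below use fixed, module-specific schemes (xavier-uniform, width-scaled normal, deformable-offset grids), so zeroing the config stds does not drive their means to 0 or 1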
if (
"self_attn.sampling_offsets.bias" in name
or "self_attn.value_proj.weight" in name
or "self_attn.output_proj.weight" in name
or "self_attn.in_proj_weight" in name
or "self_attn.out_proj.weight" in name
or "mlp.fc1.weight" in name
or "mlp.fc2.weight" in name
or "text_mapper.text_encoder.positional_embedding" in name
or "text_mapper.text_encoder.token_embedding.weight" in name
):
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_initialization_pretrained_backbone(self):
backbone_name = "microsoft/resnet-18"
# load OneFormerConfig config with a pretrained backbone
config = OneFormerConfig(
backbone=backbone_name,
use_pretrained_backbone=True,
)
# load pretrained backbone
backbone_model = AutoModelForImageClassification.from_pretrained(backbone_name, device_map=torch_device)
def params_match(params1, params2):
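# element-wise comparison of corresponding tensors; note zip stops at the shorter iterable, so both modules are assumed to expose parameters in the same order and number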
return all((p1 == p2).all() for p1, p2 in zip(params1, params2))
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device).eval()
if model.__class__.__name__ == "OneFormerModel":
self.assertTrue(
params_match(
backbone_model.base_model.encoder.parameters(),
model.pixel_level_module.encoder.encoder.parameters(),
)
)
elif model.__class__.__name__ == "OneFormerForUniversalSegmentation":
self.assertTrue(
params_match(
backbone_model.base_model.encoder.parameters(),
model.model.pixel_level_module.encoder.encoder.parameters(),
)
)
def test_training(self):
if not self.model_tester.is_training:
self.skipTest(reason="model_tester.is_training is set to False")