Fix initialization of OneFormer (#38901)

* fix initialization of OneFormer

* remove redundant initializations

* remove redundant initializations

* remove redundant initializations

* keep BC
Authored by BUI Van Tuan on 2025-06-27 12:39:37 +02:00, committed by GitHub
parent 540a10848c
commit 371c471113
2 changed files with 57 additions and 52 deletions
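Why the deletions below are safe: `PreTrainedModel.init_weights` applies `_init_weights` to every submodule exactly once, so branches that additionally looped over a composite module's parameters (the pixel decoder encoder, the three decoder sub-layer types, the task/text wrappers) re-initialized tensors that leaf-level branches already handle; the three decoder sub-layer branches also collapse into a single `OneFormerTransformerDecoderLayer` check. A minimal sketch of the dispatch pattern (illustrative toy class, not the transformers source):

    import torch.nn as nn

    # Minimal sketch of the _init_weights dispatch pattern. nn.Module.apply
    # visits every submodule exactly once, so _init_weights only needs one
    # branch per *leaf* module type; a branch that also loops over a composite
    # module's parameters initializes the same tensors a second time -- the
    # redundancy this commit removes.
    class TinyPreTrainedModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.norm = nn.LayerNorm(8)
            self.init_weights()

        def init_weights(self):
            self.apply(self._init_weights)  # recurses over all submodules

        def _init_weights(self, module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.LayerNorm):
                module.weight.data.fill_(1.0)
                module.bias.data.zero_()

    model = TinyPreTrainedModel()  # every leaf initialized exactly once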

src/transformers/models/oneformer/modeling_oneformer.py

@@ -2773,7 +2773,6 @@ class OneFormerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, OneFormerTransformerDecoder):
             nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
             nn.init.constant_(module.query_input_projection.bias, 0)
-            module.query_input_projection._is_hf_initialized = True
         elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
             nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
             thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
@@ -2793,24 +2792,9 @@ class OneFormerPreTrainedModel(PreTrainedModel):
             nn.init.constant_(module.value_proj.bias.data, 0.0)
             nn.init.xavier_uniform_(module.output_proj.weight.data)
             nn.init.constant_(module.output_proj.bias.data, 0.0)
-        elif isinstance(module, OneFormerPixelDecoderEncoderOnly):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p)
         elif isinstance(module, OneFormerPixelDecoder):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p)
             nn.init.normal_(module.level_embed, std=0)
-        elif isinstance(module, OneFormerTransformerDecoderSelfAttentionLayer):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p, gain=xavier_std)
-        elif isinstance(module, OneFormerTransformerDecoderCrossAttentionLayer):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p, gain=xavier_std)
-        elif isinstance(module, OneFormerTransformerDecoderFFNLayer):
+        elif isinstance(module, OneFormerTransformerDecoderLayer):
             for p in module.parameters():
                 if p.dim() > 1:
                     nn.init.xavier_uniform_(p, gain=xavier_std)
@@ -2818,21 +2802,6 @@ class OneFormerPreTrainedModel(PreTrainedModel):
             for p in module.parameters():
                 if p.dim() > 1:
                     nn.init.xavier_uniform_(p, gain=xavier_std)
-        elif isinstance(module, OneFormerPixelLevelModule):
-            for submodule in module.modules():
-                if isinstance(submodule, (nn.Conv2d, nn.Linear)):
-                    submodule.weight.data.normal_(mean=0.0, std=std)
-                    if submodule.bias is not None:
-                        submodule.bias.data.zero_()
-        elif isinstance(module, OneFormerTextContextDecoder):
-            for submodule in module.modules():
-                if isinstance(submodule, nn.Linear):
-                    nn.init.trunc_normal_(submodule.weight, std=0.02)
-                    if isinstance(submodule, nn.Linear) and submodule.bias is not None:
-                        nn.init.constant_(submodule.bias, 0)
-                elif isinstance(submodule, nn.LayerNorm):
-                    nn.init.constant_(submodule.bias, 0)
-                    nn.init.constant_(submodule.weight, 1.0)
         elif isinstance(module, OneFormerTextTransformer):
             proj_std = (module.width**-0.5) * ((2 * module.num_layers) ** -0.5)
             attn_std = module.width**-0.5
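The `OneFormerTextTransformer` branch kept as context above uses the CLIP/GPT-2-style residual-scaling init: attention projections get std width^-0.5, and output projections are further shrunk by (2 * num_layers)^-0.5 so residual contributions do not grow with depth. A worked example (width and num_layers values are illustrative, not OneFormer defaults):

    # CLIP-style residual scaling, matching the formulas in the hunk above.
    width, num_layers = 512, 6
    attn_std = width**-0.5                                 # ~0.0442
    proj_std = (width**-0.5) * ((2 * num_layers) ** -0.5)  # ~0.0128
    print(attn_std, proj_std)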
@@ -2848,16 +2817,11 @@ class OneFormerPreTrainedModel(PreTrainedModel):
             if hasattr(module, "reference_points"):
                 nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
                 nn.init.constant_(module.reference_points.bias.data, 0.0)
-        elif isinstance(module, OneFormerTaskModel):
-            for submodule in module.modules():
-                if isinstance(module, OneFormerMLPPredictionHead):
-                    for submodule in module.modules():
-                        if isinstance(submodule, nn.Linear):
-                            nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
-                            nn.init.constant_(submodule.bias, 0)
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+        elif isinstance(module, OneFormerMLPPredictionHead):
+            for submodule in module.modules():
+                if isinstance(submodule, nn.Linear):
+                    nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
+                    nn.init.constant_(submodule.bias, 0)
         elif isinstance(module, nn.MultiheadAttention):
             module.in_proj_weight.data.normal_(mean=0.0, std=std)
             module.in_proj_bias.data.zero_()
@@ -2865,10 +2829,15 @@
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, OneFormerLoss):
+            module.logit_scale.data.fill_(np.log(1 / self.config.contrastive_temperature))
 
     @auto_docstring
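The new `OneFormerLoss` branch fills `logit_scale` with log(1 / temperature), so exp(logit_scale) recovers the inverse temperature of the contrastive loss (the CLIP-style learnable scale). This is also why the test below sets `config.contrastive_temperature = 1`: the fill value becomes log(1) = 0, which the zero-init check expects. A quick numeric check:

    import numpy as np

    # exp(log(1/t)) == 1/t: the stored parameter round-trips back to the
    # inverse temperature; with t = 1 the stored value is exactly 0.
    for temperature in (1.0, 0.07):
        logit_scale = np.log(1 / temperature)
        print(temperature, logit_scale, np.exp(logit_scale))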

tests/models/oneformer/test_modeling_oneformer.py

@@ -13,14 +13,13 @@
 # limitations under the License.
 """Testing suite for the PyTorch OneFormer model."""
 
-import copy
 import inspect
 import unittest
 
 import numpy as np
 
 from tests.test_modeling_common import floats_tensor
-from transformers import OneFormerConfig, is_torch_available, is_vision_available
+from transformers import AutoModelForImageClassification, OneFormerConfig, is_torch_available, is_vision_available
 from transformers.testing_utils import (
     is_flaky,
     require_timm,
@@ -35,7 +34,7 @@ from transformers.testing_utils import (
 from transformers.utils import cached_property
 
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init
 from ...test_pipeline_mixin import PipelineTesterMixin
 
@@ -51,14 +50,6 @@
     from PIL import Image
 
 
-def _config_zero_init(config):
-    configs_no_init = copy.deepcopy(config)
-    for key in configs_no_init.__dict__.keys():
-        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
-            setattr(configs_no_init, key, 1e-10)
-    return configs_no_init
-
-
 class OneFormerModelTester:
     def __init__(
         self,
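The deleted local helper was a copy of `_config_zero_init` in `tests/test_modeling_common.py`, which the import hunk above now pulls in instead, so nothing changes behaviorally. A sketch of what the helper buys the test (the `init_std` field name is an assumption about OneFormerConfig's defaults):

    from transformers import OneFormerConfig
    from tests.test_modeling_common import _config_zero_init

    # _config_zero_init deep-copies the config and forces every *_range /
    # *_std / initializer_factor / layer_scale field to ~1e-10. A model built
    # from it should have parameter means of 0.0 (normal/zero init) or 1.0
    # (norm-layer weights); any other mean exposes an init path that ignores
    # the config -- exactly what test_initialization asserts below.
    config = OneFormerConfig()
    print(config.init_std)                      # library default (assumed 0.02)
    print(_config_zero_init(config).init_std)   # ~1e-10 after zeroing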
@@ -375,6 +366,7 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.is_training = True
         config.contrastive_temperature = 1
 
         configs_no_init = _config_zero_init(config)
@@ -382,12 +374,56 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
             model = model_class(config=configs_no_init)
             for name, param in model.named_parameters():
                 if param.requires_grad:
+                    if (
+                        "self_attn.sampling_offsets.bias" in name
+                        or "self_attn.value_proj.weight" in name
+                        or "self_attn.output_proj.weight" in name
+                        or "self_attn.in_proj_weight" in name
+                        or "self_attn.out_proj.weight" in name
+                        or "mlp.fc1.weight" in name
+                        or "mlp.fc2.weight" in name
+                        or "text_mapper.text_encoder.positional_embedding" in name
+                        or "text_mapper.text_encoder.token_embedding.weight" in name
+                    ):
+                        continue
                     self.assertIn(
                         ((param.data.mean() * 1e9).round() / 1e9).item(),
                         [0.0, 1.0],
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
 
+    def test_initialization_pretrained_backbone(self):
+        backbone_name = "microsoft/resnet-18"
+
+        # load OneFormerConfig config with a pretrained backbone
+        config = OneFormerConfig(
+            backbone=backbone_name,
+            use_pretrained_backbone=True,
+        )
+
+        # load pretrained backbone
+        backbone_model = AutoModelForImageClassification.from_pretrained(backbone_name, device_map=torch_device)
+
+        def params_match(params1, params2):
+            return all((p1 == p2).all() for p1, p2 in zip(params1, params2))
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device).eval()
+            if model.__class__.__name__ == "OneFormerModel":
+                self.assertTrue(
+                    params_match(
+                        backbone_model.base_model.encoder.parameters(),
+                        model.pixel_level_module.encoder.encoder.parameters(),
+                    )
+                )
+            elif model.__class__.__name__ == "OneFormerForUniversalSegmentation":
+                self.assertTrue(
+                    params_match(
+                        backbone_model.base_model.encoder.parameters(),
+                        model.model.pixel_level_module.encoder.encoder.parameters(),
+                    )
+                )
+
     def test_training(self):
         if not self.model_tester.is_training:
             self.skipTest(reason="model_tester.is_training is set to False")
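The `params_match` helper added above compares parameters pairwise in iteration order, which is safe here because both iterators walk the same backbone architecture. A standalone illustration with toy modules (not OneFormer):

    import torch.nn as nn

    # Two modules share weights iff every parameter tensor matches
    # elementwise, compared pairwise in parameters() iteration order.
    def params_match(params1, params2):
        return all((p1 == p2).all() for p1, p2 in zip(params1, params2))

    a = nn.Linear(4, 4)
    b = nn.Linear(4, 4)
    b.load_state_dict(a.state_dict())  # copy a's weights into b
    print(params_match(a.parameters(), b.parameters()))  # True
    print(params_match(a.parameters(), nn.Linear(4, 4).parameters()))  # False: fresh init

Both the tightened test_initialization and the new backbone test can be run with: python -m pytest tests/models/oneformer/test_modeling_oneformer.py -k initialization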