diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py
index 05e22056a51..28eadd3a489 100644
--- a/src/transformers/models/oneformer/modeling_oneformer.py
+++ b/src/transformers/models/oneformer/modeling_oneformer.py
@@ -2773,7 +2773,6 @@ class OneFormerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, OneFormerTransformerDecoder):
             nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
             nn.init.constant_(module.query_input_projection.bias, 0)
-            module.query_input_projection._is_hf_initialized = True
         elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
             nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
             thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
@@ -2793,24 +2792,9 @@ class OneFormerPreTrainedModel(PreTrainedModel):
             nn.init.constant_(module.value_proj.bias.data, 0.0)
             nn.init.xavier_uniform_(module.output_proj.weight.data)
             nn.init.constant_(module.output_proj.bias.data, 0.0)
-        elif isinstance(module, OneFormerPixelDecoderEncoderOnly):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p)
         elif isinstance(module, OneFormerPixelDecoder):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p)
             nn.init.normal_(module.level_embed, std=0)
-        elif isinstance(module, OneFormerTransformerDecoderSelfAttentionLayer):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p, gain=xavier_std)
-        elif isinstance(module, OneFormerTransformerDecoderCrossAttentionLayer):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.xavier_uniform_(p, gain=xavier_std)
-        elif isinstance(module, OneFormerTransformerDecoderFFNLayer):
+        elif isinstance(module, OneFormerTransformerDecoderLayer):
             for p in module.parameters():
                 if p.dim() > 1:
                     nn.init.xavier_uniform_(p, gain=xavier_std)
@@ -2818,21 +2802,6 @@ class OneFormerPreTrainedModel(PreTrainedModel):
             for p in module.parameters():
                 if p.dim() > 1:
                     nn.init.xavier_uniform_(p, gain=xavier_std)
-        elif isinstance(module, OneFormerPixelLevelModule):
-            for submodule in module.modules():
-                if isinstance(submodule, (nn.Conv2d, nn.Linear)):
-                    submodule.weight.data.normal_(mean=0.0, std=std)
-                    if submodule.bias is not None:
-                        submodule.bias.data.zero_()
-        elif isinstance(module, OneFormerTextContextDecoder):
-            for submodule in module.modules():
-                if isinstance(submodule, nn.Linear):
-                    nn.init.trunc_normal_(submodule.weight, std=0.02)
-                if isinstance(submodule, nn.Linear) and submodule.bias is not None:
-                    nn.init.constant_(submodule.bias, 0)
-                elif isinstance(submodule, nn.LayerNorm):
-                    nn.init.constant_(submodule.bias, 0)
-                    nn.init.constant_(submodule.weight, 1.0)
         elif isinstance(module, OneFormerTextTransformer):
             proj_std = (module.width**-0.5) * ((2 * module.num_layers) ** -0.5)
             attn_std = module.width**-0.5
@@ -2848,16 +2817,11 @@ class OneFormerPreTrainedModel(PreTrainedModel):
         if hasattr(module, "reference_points"):
             nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
             nn.init.constant_(module.reference_points.bias.data, 0.0)
-        elif isinstance(module, OneFormerTaskModel):
+        elif isinstance(module, OneFormerMLPPredictionHead):
             for submodule in module.modules():
-                if isinstance(module, OneFormerMLPPredictionHead):
-                    for submodule in module.modules():
-                        if isinstance(submodule, nn.Linear):
-                            nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
-                            nn.init.constant_(submodule.bias, 0)
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+                if isinstance(submodule, nn.Linear):
+                    nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
+                    nn.init.constant_(submodule.bias, 0)
         elif isinstance(module, nn.MultiheadAttention):
             module.in_proj_weight.data.normal_(mean=0.0, std=std)
             module.in_proj_bias.data.zero_()
@@ -2865,10 +2829,15 @@ class OneFormerPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, OneFormerLoss):
+            module.logit_scale.data.fill_(np.log(1 / self.config.contrastive_temperature))
 
 
 @auto_docstring
diff --git a/tests/models/oneformer/test_modeling_oneformer.py b/tests/models/oneformer/test_modeling_oneformer.py
index 0ce791dd3c9..58a93a8c4fa 100644
--- a/tests/models/oneformer/test_modeling_oneformer.py
+++ b/tests/models/oneformer/test_modeling_oneformer.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 """Testing suite for the PyTorch OneFormer model."""
 
-import copy
 import inspect
 import unittest
 
 import numpy as np
 
 from tests.test_modeling_common import floats_tensor
-from transformers import OneFormerConfig, is_torch_available, is_vision_available
+from transformers import AutoModelForImageClassification, OneFormerConfig, is_torch_available, is_vision_available
 from transformers.testing_utils import (
     is_flaky,
     require_timm,
@@ -35,7 +34,7 @@ from transformers.testing_utils import (
 from transformers.utils import cached_property
 
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init
 from ...test_pipeline_mixin import PipelineTesterMixin
 
 
@@ -51,14 +50,6 @@ if is_vision_available():
     from PIL import Image
 
 
-def _config_zero_init(config):
-    configs_no_init = copy.deepcopy(config)
-    for key in configs_no_init.__dict__.keys():
-        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
-            setattr(configs_no_init, key, 1e-10)
-    return configs_no_init
-
-
 class OneFormerModelTester:
     def __init__(
         self,
@@ -375,6 +366,7 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
+        config.is_training = True
         config.contrastive_temperature = 1
 
         configs_no_init = _config_zero_init(config)
@@ -382,12 +374,56 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
             model = model_class(config=configs_no_init)
             for name, param in model.named_parameters():
                 if param.requires_grad:
+                    if (
+                        "self_attn.sampling_offsets.bias" in name
+                        or "self_attn.value_proj.weight" in name
+                        or "self_attn.output_proj.weight" in name
+                        or "self_attn.in_proj_weight" in name
+                        or "self_attn.out_proj.weight" in name
+                        or "mlp.fc1.weight" in name
+                        or "mlp.fc2.weight" in name
+                        or "text_mapper.text_encoder.positional_embedding" in name
+                        or "text_mapper.text_encoder.token_embedding.weight" in name
+                    ):
+                        continue
                     self.assertIn(
                         ((param.data.mean() * 1e9).round() / 1e9).item(),
                         [0.0, 1.0],
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
 
+    def test_initialization_pretrained_backbone(self):
+        backbone_name = "microsoft/resnet-18"
+
+        # load OneFormerConfig config with a pretrained backbone
+        config = OneFormerConfig(
+            backbone=backbone_name,
+            use_pretrained_backbone=True,
+        )
+
+        # load pretrained backbone
+        backbone_model = AutoModelForImageClassification.from_pretrained(backbone_name, device_map=torch_device)
+
+        def params_match(params1, params2):
+            return all((p1 == p2).all() for p1, p2 in zip(params1, params2))
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device).eval()
+            if model.__class__.__name__ == "OneFormerModel":
+                self.assertTrue(
+                    params_match(
+                        backbone_model.base_model.encoder.parameters(),
+                        model.pixel_level_module.encoder.encoder.parameters(),
+                    )
+                )
+            elif model.__class__.__name__ == "OneFormerForUniversalSegmentation":
+                self.assertTrue(
+                    params_match(
+                        backbone_model.base_model.encoder.parameters(),
+                        model.model.pixel_level_module.encoder.encoder.parameters(),
+                    )
+                )
+
     def test_training(self):
         if not self.model_tester.is_training:
             self.skipTest(reason="model_tester.is_training is set to False")