diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 86091062923..47a4a18d1a1 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -23,7 +23,7 @@ import numpy as np import torch from torch import Tensor, nn -from ... import AutoBackbone, SwinConfig +from ... import AutoBackbone from ...activations import ACT2FN from ...file_utils import ( ModelOutput, @@ -1388,10 +1388,7 @@ class Mask2FormerPixelLevelModule(nn.Module): """ super().__init__() - backbone_config_dict = config.backbone_config.to_dict() - backbone_config = SwinConfig.from_dict(backbone_config_dict) - - self.encoder = AutoBackbone.from_config(backbone_config) + self.encoder = AutoBackbone.from_config(config.backbone_config) self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: