diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 45c555ee180..5b9ecbeccfa 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -24,7 +24,13 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import tensorflow as tf from ...activations_tf import ACT2FN -from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, get_initializer, unpack_inputs +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras_serializable, + unpack_inputs, +) from ...tf_utils import shape_list from ...utils import ( ModelOutput, @@ -1069,15 +1075,14 @@ class AdaptiveAveragePooling1D(tf.keras.layers.Layer): return {**base_config, **config} -@add_start_docstrings( - "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.", - SWIN_START_DOCSTRING, -) -class TFSwinModel(TFSwinPreTrainedModel): +@keras_serializable +class TFSwinMainLayer(tf.keras.layers.Layer): + config_class = SwinConfig + def __init__( self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs ) -> None: - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.config = config self.num_layers = len(config.depths) self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) @@ -1104,15 +1109,6 @@ class TFSwinModel(TFSwinPreTrainedModel): raise NotImplementedError return [None] * len(self.config.depths) - @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_FEAT_EXTRACTOR_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSwinModelOutput, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) @unpack_inputs def call( self, @@ -1175,6 +1171,60 @@ class TFSwinModel(TFSwinPreTrainedModel): ) +@add_start_docstrings( + "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.", + SWIN_START_DOCSTRING, +) +class TFSwinModel(TFSwinPreTrainedModel): + def __init__( + self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs + ) -> None: + super().__init__(config, **kwargs) + self.config = config + self.swin = TFSwinMainLayer(config, name="swin") + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSwinModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + bool_masked_pos: Optional[tf.Tensor] = None, + head_mask: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + swin_outputs = self.swin( + pixel_values=pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return swin_outputs + + class PixelShuffle(tf.keras.layers.Layer): """TF layer implementation of torch.nn.PixelShuffle""" @@ -1238,7 +1288,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): def __init__(self, config: SwinConfig): super().__init__(config) - self.swin = TFSwinModel(config, add_pooling_layer=False, use_mask_token=True, name="swin") + self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin") self.decoder = TFSwinDecoder(config, name="decoder") @@ -1350,7 +1400,7 @@ class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificati super().__init__(config) self.num_labels = config.num_labels - self.swin = TFSwinModel(config, name="swin") + self.swin = TFSwinMainLayer(config, name="swin") # Classifier head self.classifier = (