mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
3960ce917f
commit
bd43151af4
@ -24,7 +24,13 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
import tensorflow as tf
|
||||
|
||||
from ...activations_tf import ACT2FN
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, get_initializer, unpack_inputs
|
||||
from ...modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
TFSequenceClassificationLoss,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
@ -1069,15 +1075,14 @@ class AdaptiveAveragePooling1D(tf.keras.layers.Layer):
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
SWIN_START_DOCSTRING,
|
||||
)
|
||||
class TFSwinModel(TFSwinPreTrainedModel):
|
||||
@keras_serializable
|
||||
class TFSwinMainLayer(tf.keras.layers.Layer):
|
||||
config_class = SwinConfig
|
||||
|
||||
def __init__(
|
||||
self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
|
||||
) -> None:
|
||||
super().__init__(config, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.num_layers = len(config.depths)
|
||||
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
|
||||
@ -1104,15 +1109,6 @@ class TFSwinModel(TFSwinPreTrainedModel):
|
||||
raise NotImplementedError
|
||||
return [None] * len(self.config.depths)
|
||||
|
||||
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFSwinModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
modality="vision",
|
||||
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||
)
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
@ -1175,6 +1171,60 @@ class TFSwinModel(TFSwinPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
SWIN_START_DOCSTRING,
|
||||
)
|
||||
class TFSwinModel(TFSwinPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
|
||||
) -> None:
|
||||
super().__init__(config, **kwargs)
|
||||
self.config = config
|
||||
self.swin = TFSwinMainLayer(config, name="swin")
|
||||
|
||||
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFSwinModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
modality="vision",
|
||||
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||
)
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
pixel_values: Optional[tf.Tensor] = None,
|
||||
bool_masked_pos: Optional[tf.Tensor] = None,
|
||||
head_mask: Optional[tf.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
swin_outputs = self.swin(
|
||||
pixel_values=pixel_values,
|
||||
bool_masked_pos=bool_masked_pos,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
return swin_outputs
|
||||
|
||||
|
||||
class PixelShuffle(tf.keras.layers.Layer):
|
||||
"""TF layer implementation of torch.nn.PixelShuffle"""
|
||||
|
||||
@ -1238,7 +1288,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
|
||||
def __init__(self, config: SwinConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.swin = TFSwinModel(config, add_pooling_layer=False, use_mask_token=True, name="swin")
|
||||
self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin")
|
||||
|
||||
self.decoder = TFSwinDecoder(config, name="decoder")
|
||||
|
||||
@ -1350,7 +1400,7 @@ class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificati
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.swin = TFSwinModel(config, name="swin")
|
||||
self.swin = TFSwinMainLayer(config, name="swin")
|
||||
|
||||
# Classifier head
|
||||
self.classifier = (
|
||||
|
Loading…
Reference in New Issue
Block a user