Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 03:01:07 +06:00
Use class decorator instead of superclass
When supplied by Keras deserialization, the `config` parameter to initializers will be a dict. So intercept it and convert it to a PretrainedConfig object (storing it in an instance attribute so that `get_config` can retrieve it) before passing it to the actual initializer. To accomplish this while repeating as little code as possible, use a class decorator on the TF*MainLayer classes.
This commit is contained in:
parent b8da16f390
commit 0c716ede8c
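For context before the diff, a standalone sketch (illustrative, not part of the commit) of the failure mode being addressed: Keras stores whatever get_config() returns as JSON and passes it back to __init__ on load, so a config saved as a dict comes back as a dict.

import tensorflow as tf


class Inner(tf.keras.layers.Layer):
    # Toy stand-in for a TF*MainLayer; `config` here is just a plain dict.
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

    def get_config(self):
        cfg = super().get_config()
        cfg["config"] = dict(self.config)  # must be JSON-serializable
        return cfg


layer = Inner({"hidden_size": 4})
restored = Inner.from_config(layer.get_config())
assert isinstance(restored.config, dict)  # the dict the decorator below must intercept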
@@ -17,6 +17,7 @@
 import logging
 from collections import OrderedDict
+from importlib import import_module
 
 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
 from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
@@ -100,6 +101,20 @@ class AutoConfig:
             "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
         )
 
+    @classmethod
+    def config_class_for_model_class(cls, model_class):
+        module = import_module(model_class.__module__)
+        return next(
+            (
+                module_attribute
+                for module_attribute_name in dir(module)
+                if module_attribute_name.endswith("Config")
+                for module_attribute in (getattr(module, module_attribute_name),)
+                if issubclass(module_attribute, PretrainedConfig)
+            ),
+            None,
+        )
+
     @classmethod
     def for_model(cls, model_type, *args, **kwargs):
         for pattern, config_class in CONFIG_MAPPING.items():
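A hypothetical usage sketch of the new classmethod (assumes this commit is applied): it scans the model class's module for an attribute ending in "Config" that subclasses PretrainedConfig.

from transformers import AutoConfig
from transformers.modeling_tf_bert import TFBertMainLayer

# Should resolve to BertConfig, the matching *Config in TFBertMainLayer's module.
config_class = AutoConfig.config_class_for_model_class(TFBertMainLayer)
config = config_class.from_dict({"vocab_size": 128, "hidden_size": 32})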
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 
 logger = logging.getLogger(__name__)
@@ -478,9 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
         return hidden_states
 
 
-class TFAlbertMainLayer(TFMainLayer):
+@keras_serializable
+class TFAlbertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
 
         self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
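The same two-line change repeats below for the other TF*MainLayer classes (BERT, CTRL, DistilBERT, GPT-2, OpenAI GPT, T5, Transfo-XL, XLM, XLNet); schematically, with a toy class standing in for real model code and assuming this commit's keras_serializable is importable:

import tensorflow as tf

from transformers import BertConfig
from transformers.modeling_tf_utils import keras_serializable


@keras_serializable
class TFToyMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)  # config is no longer forwarded to the superclass
        self.num_hidden_layers = config.num_hidden_layers


layer = TFToyMainLayer(BertConfig())  # PretrainedConfig instances still work as before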
@@ -23,7 +23,7 @@ import tensorflow as tf
 
 from .configuration_bert import BertConfig
 from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 
 logger = logging.getLogger(__name__)
@@ -471,9 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer):
         return seq_relationship_score
 
 
-class TFBertMainLayer(TFMainLayer):
+@keras_serializable
+class TFBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
 
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
@@ -23,7 +23,7 @@ import tensorflow as tf
 
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
 
 
 logger = logging.getLogger(__name__)
@@ -164,9 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
         return outputs
 
 
-class TFCTRLMainLayer(TFMainLayer):
+@keras_serializable
+class TFCTRLMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.output_past = config.output_past
@@ -24,7 +24,7 @@ import tensorflow as tf
 
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
 
 
 logger = logging.getLogger(__name__)
@@ -397,9 +397,10 @@ class TFTransformer(tf.keras.layers.Layer):
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
 
 
-class TFDistilBertMainLayer(TFMainLayer):
+@keras_serializable
+class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
 
         self.embeddings = TFEmbeddings(config, name="embeddings")  # Embeddings
@@ -25,11 +25,11 @@ from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
     TFConv1D,
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
 
@@ -197,9 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
         return outputs  # x, present, (attentions)
 
 
-class TFGPT2MainLayer(TFMainLayer):
+@keras_serializable
+class TFGPT2MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+        super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.num_hidden_layers = config.n_layer
@@ -25,11 +25,11 @@ from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
     TFConv1D,
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
 
@@ -198,9 +198,10 @@ class TFBlock(tf.keras.layers.Layer):
         return outputs  # x, (attentions)
 
 
-class TFOpenAIGPTMainLayer(TFMainLayer):
+@keras_serializable
+class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+        super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.num_hidden_layers = config.n_layer
@@ -20,10 +20,11 @@ import logging
 
 import tensorflow as tf
 
+from . import PretrainedConfig
 from .configuration_roberta import RobertaConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 
 logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ import tensorflow as tf
 
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
 
 
 logger = logging.getLogger(__name__)
@@ -359,9 +359,10 @@ class TFT5Block(tf.keras.layers.Layer):
 # The full model without a specific pretrained or finetuning head is
 # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
 ####################################################
-class TFT5MainLayer(TFMainLayer):
+@keras_serializable
+class TFT5MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.is_decoder = config.is_decoder
@@ -383,14 +384,21 @@ class TFT5MainLayer(TFMainLayer):
 
     def call(
         self,
-        hidden_states,
+        inputs,
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         head_mask=None,
         training=False,
     ):
+        if isinstance(inputs, (tuple, list)):
+            hidden_states = inputs[0]
+            assert len(inputs) <= 1, "Too many inputs."
+        elif isinstance(inputs, dict):
+            hidden_states = inputs["hidden_states"]
+            assert len(inputs) <= 1, "Too many inputs."
+        else:
+            hidden_states = inputs
         batch_size, seq_length = shape_list(hidden_states)[:2]
         if attention_mask is None:
             attention_mask = tf.fill((batch_size, seq_length), 1)
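A standalone sketch of the tuple/dict/tensor dispatch the reworked call() gains (toy layer mirroring the logic above, not TFT5MainLayer itself):

import tensorflow as tf


class ToyLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        # Same dispatch as the new TFT5MainLayer.call above.
        if isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            assert len(inputs) <= 1, "Too many inputs."
        elif isinstance(inputs, dict):
            hidden_states = inputs["hidden_states"]
            assert len(inputs) <= 1, "Too many inputs."
        else:
            hidden_states = inputs
        return hidden_states * 2.0


layer = ToyLayer()
x = tf.ones((2, 5))
# All three call conventions now reach the same code path:
assert layer(x).shape == layer([x]).shape == layer({"hidden_states": x}).shape

This is what lets Keras feed the layer a dict of named inputs when it rebuilds a model from a saved file.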
@@ -24,7 +24,7 @@ import tensorflow as tf
 from .configuration_transfo_xl import TransfoXLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 
 logger = logging.getLogger(__name__)
@@ -378,9 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         return embed
 
 
-class TFTransfoXLMainLayer(TFMainLayer):
+@keras_serializable
+class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
 
@@ -47,21 +47,31 @@ class TFModelUtilsMixin:
         return self.count_params()
 
 
-class TFMainLayer(tf.keras.layers.Layer):
-    """
-    A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization.
-    """
+def keras_serializable(cls):
+    initializer = cls.__init__
 
-    def __init__(self, config, **kwargs):
-        super().__init__(**kwargs)
+    def wrapped_init(self, config, *args, **kwargs):
         if isinstance(config, dict):
-            config = PretrainedConfig.from_dict(config)
+            from transformers import AutoConfig
+
+            config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
+        initializer(self, config, *args, **kwargs)
         self._transformers_config = config
 
-    def get_config(self):
-        cfg = super().get_config()
-        cfg["config"] = self._transformers_config.to_dict()
-        return cfg
+    cls.__init__ = wrapped_init
+
+    if not hasattr(cls, "get_config"):
+        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
+    if hasattr(cls.get_config, "_is_default"):
+
+        def get_config(self):
+            cfg = super(cls, self).get_config()
+            cfg["config"] = self._transformers_config.to_dict()
+            return cfg
+
+        cls.get_config = get_config
+
+    return tf.keras.utils.register_keras_serializable()(cls)
 
 
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
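Putting the pieces together, a minimal end-to-end sketch of the round trip the decorator enables (assumes this commit is applied; the tiny config values and file name are illustrative):

import tensorflow as tf

from transformers import BertConfig
from transformers.modeling_tf_bert import TFBertMainLayer

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
main_layer = TFBertMainLayer(config)

# Build a functional model around the main layer, as the tests below do.
input_ids = tf.keras.Input(shape=(None,), dtype=tf.int32)
model = tf.keras.Model(input_ids, main_layer(input_ids)[0])

model.save("bert_main_layer.h5")  # get_config() stores the config as a plain dict
restored = tf.keras.models.load_model(
    "bert_main_layer.h5", custom_objects={"TFBertMainLayer": TFBertMainLayer}
)  # on load, wrapped_init turns that dict back into a BertConfig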
@@ -26,11 +26,11 @@ import tensorflow as tf
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
 
@@ -203,9 +203,10 @@ class TFTransformerFFN(tf.keras.layers.Layer):
         return x
 
 
-class TFXLMMainLayer(TFMainLayer):
+@keras_serializable
+class TFXLMMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
 
@@ -25,11 +25,11 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
 
@@ -349,9 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
         return hidden_states
 
 
-class TFXLNetMainLayer(TFMainLayer):
+@keras_serializable
+class TFXLNetMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.output_past = config.output_past
@@ -22,7 +22,6 @@ import unittest
 from importlib import import_module
 
 from transformers import is_tf_available, is_torch_available
-from transformers.modeling_tf_utils import TFMainLayer
 
 from .utils import _tf_gpu_memory_limit, require_tf
 
@@ -90,6 +89,7 @@ class TFModelTesterMixin:
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
                 after_outputs = model(inputs_dict)
+
                 self.assert_outputs_same(after_outputs, outputs)
 
     def test_keras_save_load(self):
@@ -100,10 +100,14 @@ class TFModelTesterMixin:
             for model_class in self.all_model_classes
             for module in (import_module(model_class.__module__),)
             for module_member_name in dir(module)
+            if module_member_name.endswith("MainLayer")
             for module_member in (getattr(module, module_member_name),)
-            if isinstance(module_member, type) and TFMainLayer in module_member.__bases__
+            if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
         )
         for main_layer_class in tf_main_layer_classes:
+            if main_layer_class.__name__ == "TFT5MainLayer":
+                # Not really a "main layer" as in the other models, as this one doesn't receive the test inputs directly
+                continue
             main_layer = main_layer_class(config)
             symbolic_inputs = {
                 name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
@@ -125,6 +129,7 @@ class TFModelTesterMixin:
         # Make sure we don't have nans
         out_1 = after_outputs[0].numpy()
         out_2 = outputs[0].numpy()
+        self.assertEqual(out_1.shape, out_2.shape)
        out_1 = out_1[~np.isnan(out_1)]
         out_2 = out_2[~np.isnan(out_2)]
         max_diff = np.amax(np.abs(out_1 - out_2))
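One note on the added shape assertion: boolean-mask indexing with ~np.isnan flattens the arrays, so without an explicit shape check two differently-shaped outputs could still compare as numerically identical. A small numpy illustration:

import numpy as np

out_1 = np.zeros((2, 3))
out_2 = np.zeros((3, 2))
# After nan-filtering, both flatten to six zeros, so the max |diff| is 0.0 ...
assert np.amax(np.abs(out_1[~np.isnan(out_1)] - out_2[~np.isnan(out_2)])) == 0.0
# ... even though out_1.shape != out_2.shape, hence the assertEqual on shapes first.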