diff --git a/transformers/modeling_tf_t5.py b/transformers/modeling_tf_t5.py
index c1de4745c25..11762ee1e52 100644
--- a/transformers/modeling_tf_t5.py
+++ b/transformers/modeling_tf_t5.py
@@ -17,16 +17,11 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

-import json
 import logging
 import math
-import os
-import sys
 import copy
 import itertools
-from io import open

-import numpy as np
 import tensorflow as tf

 from .configuration_t5 import T5Config
@@ -45,6 +40,28 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
 # - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model)
 ####################################################

+class TFT5LayerNorm(tf.keras.layers.Layer):
+    def __init__(self, epsilon=1e-6, **kwargs):
+        """ Construct a layernorm module in the T5 style
+            No bias and no subtraction of mean.
+        """
+        super(TFT5LayerNorm, self).__init__(**kwargs)
+        self.variance_epsilon = epsilon
+
+    def build(self, input_shape):
+        """Build the layer norm scale weight """
+        self.weight = self.add_weight(
+            "weight",
+            shape=(input_shape[-1],),
+            initializer='ones')
+        super(TFT5LayerNorm, self).build(input_shape)
+
+    def call(self, x):
+        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
+        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
+        return self.weight * x
+
+
 class TFT5DenseReluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5DenseReluDense, self).__init__(**kwargs)
@@ -65,8 +82,8 @@ class TFT5LayerFF(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5LayerFF, self).__init__(**kwargs)
         self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                             name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, training=False):
@@ -249,8 +266,8 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         self.SelfAttention = TFT5Attention(config,
                                            has_relative_attention_bias=has_relative_attention_bias,
                                            name='SelfAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                             name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, attention_mask=None, position_bias=None,
@@ -273,8 +290,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
         self.EncDecAttention = TFT5Attention(config,
                                              has_relative_attention_bias=has_relative_attention_bias,
                                              name='EncDecAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                             name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, kv, attention_mask=None, position_bias=None,
@@ -353,8 +370,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                                 has_relative_attention_bias=bool(i == 0),
                                 name='block_._{}'.format(i))
                       for i in range(config.num_layers)]
-        self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                                   name='final_layer_norm')
+        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                              name='final_layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def _resize_token_embeddings(self, new_num_tokens):
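
Note (not part of the diff): the new `TFT5LayerNorm` replaces `tf.keras.layers.LayerNormalization` because the T5 reference implementation rescales activations by their root mean square only, with a single learned scale and no bias and no mean subtraction. The snippet below is a minimal standalone sketch contrasting the two behaviours; it assumes TensorFlow 2 and uses an illustrative tensor, nothing from this file.

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0, 10.0]])

# Stock Keras layer norm: subtract the mean, divide by the standard deviation,
# then apply a learned scale (gamma) and bias (beta).
keras_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6)
print(keras_ln(x).numpy())  # roughly zero-mean, unit-variance output

# T5-style norm (what TFT5LayerNorm.call computes): rescale by the root mean
# square only. The learned `weight` is all ones at initialization, so it is
# omitted here; the sign pattern of x is preserved and its RMS becomes ~1.
variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
print((x * tf.math.rsqrt(variance + 1e-6)).numpy())
```

A custom layer with a single `weight` variable (rather than the `gamma`/`beta` pair of the stock layer) also presumably keeps the variable layout aligned with the original T5 checkpoints.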