updating tf 2.0 layer_norm to T5 layer norm

This commit is contained in:
thomwolf 2019-12-10 10:01:01 +01:00
parent 8e651f56b7
commit 608a8f5b56

View File

@ -17,16 +17,11 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import math
import os
import sys
import copy
import itertools
from io import open
import numpy as np
import tensorflow as tf
from .configuration_t5 import T5Config
@ -45,6 +40,28 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
# - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model)
####################################################
class TFT5LayerNorm(tf.keras.layers.Layer):
def __init__(self, epsilon=1e-6, **kwargs):
""" Construct a layernorm module in the T5 style
No bias and no substraction of mean.
"""
super(TFT5LayerNorm, self).__init__(**kwargs)
self.variance_epsilon = epsilon
def build(self, input_shape):
"""Build shared word embedding layer """
self.weight = self.add_weight(
"weight",
shape=(input_shape[-1],),
initializer='ones')
super(TFT5LayerNorm, self).build(input_shape)
def call(self, x):
variance = tf.math.reduce_min(tf.math.square(x), axis=-1, keepdims=True)
x = x * tf.math.rsqrt(variance + self.variance_epsilon)
return self.weight * x
class TFT5DenseReluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFT5DenseReluDense, self).__init__(**kwargs)
@ -65,8 +82,8 @@ class TFT5LayerFF(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFT5LayerFF, self).__init__(**kwargs)
self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, training=False):
@ -249,8 +266,8 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
self.SelfAttention = TFT5Attention(config,
has_relative_attention_bias=has_relative_attention_bias,
name='SelfAttention')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, attention_mask=None, position_bias=None,
@ -273,8 +290,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
self.EncDecAttention = TFT5Attention(config,
has_relative_attention_bias=has_relative_attention_bias,
name='EncDecAttention')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, kv, attention_mask=None, position_bias=None,
@ -353,8 +370,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
has_relative_attention_bias=bool(i == 0),
name='block_._{}'.format(i))
for i in range(config.num_layers)]
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='final_layer_norm')
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='final_layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def _resize_token_embeddings(self, new_num_tokens):