updating tf 2.0 layer_norm to T5 layer norm

thomwolf 2019-12-10 10:01:01 +01:00
parent 8e651f56b7
commit 608a8f5b56


@@ -17,16 +17,11 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
-import json
 import logging
 import math
-import os
-import sys
 import copy
 import itertools
-from io import open
-import numpy as np
 import tensorflow as tf
 from .configuration_t5 import T5Config
@@ -45,6 +40,28 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
 # - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model)
 ####################################################
+class TFT5LayerNorm(tf.keras.layers.Layer):
+    def __init__(self, epsilon=1e-6, **kwargs):
+        """ Construct a layernorm module in the T5 style
+            No bias and no subtraction of mean.
+        """
+        super(TFT5LayerNorm, self).__init__(**kwargs)
+        self.variance_epsilon = epsilon
+
+    def build(self, input_shape):
+        """ Build the layer norm scale weight """
+        self.weight = self.add_weight(
+            "weight",
+            shape=(input_shape[-1],),
+            initializer='ones')
+        super(TFT5LayerNorm, self).build(input_shape)
+
+    def call(self, x):
+        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
+        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
+        return self.weight * x
 class TFT5DenseReluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5DenseReluDense, self).__init__(**kwargs)
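For reference, the TFT5LayerNorm added above only rescales activations by the reciprocal root-mean-square of the last axis; unlike the tf.keras.layers.LayerNormalization it replaces, it neither subtracts the mean nor adds a bias. A minimal standalone sketch of the same computation (the function name t5_style_layer_norm is illustrative and not part of this commit):

import tensorflow as tf

def t5_style_layer_norm(x, weight, epsilon=1e-6):
    # "variance" here is the mean of squares over the last axis; no mean subtraction
    variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
    return weight * x * tf.math.rsqrt(variance + epsilon)

x = tf.random.normal((2, 4, 8))
weight = tf.ones((8,))          # the layer initializes its scale to ones
out = t5_style_layer_norm(x, weight)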
@@ -65,8 +82,8 @@ class TFT5LayerFF(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5LayerFF, self).__init__(**kwargs)
         self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                              name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, training=False):
@@ -249,8 +266,8 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         self.SelfAttention = TFT5Attention(config,
                                            has_relative_attention_bias=has_relative_attention_bias,
                                            name='SelfAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                              name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, attention_mask=None, position_bias=None,
@@ -273,8 +290,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
         self.EncDecAttention = TFT5Attention(config,
                                              has_relative_attention_bias=has_relative_attention_bias,
                                              name='EncDecAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                              name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, kv, attention_mask=None, position_bias=None,
@@ -353,8 +370,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                                 has_relative_attention_bias=bool(i == 0),
                                 name='block_._{}'.format(i))
                       for i in range(config.num_layers)]
-        self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                                    name='final_layer_norm')
+        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                              name='final_layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def _resize_token_embeddings(self, new_num_tokens):
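
As a quick usage sanity check (a sketch assuming the TFT5LayerNorm class from this diff is in scope, not part of the commit itself): build() creates a single 'weight' vector of shape (hidden_size,) initialized to ones, so the layer can be called in place of the tf.keras layer norm it replaces.

import tensorflow as tf

layer_norm = TFT5LayerNorm(epsilon=1e-6, name='layer_norm')
hidden_states = tf.random.normal((1, 3, 512))   # (batch, sequence, d_model)
normed = layer_norm(hidden_states)              # first call triggers build()
assert layer_norm.weight.shape == (512,)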