mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
updating tf 2.0 layer_norm to T5 layer norm
This commit is contained in:
parent 8e651f56b7
commit 608a8f5b56
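
For context on the change below: the T5-style layer norm is a root-mean-square normalization, i.e. it divides the activations by their RMS over the last axis and applies a learned per-feature scale, with no mean subtraction and no bias term (unlike tf.keras.layers.LayerNormalization). A minimal standalone sketch of that computation follows; it is not part of the commit, and the function name, toy tensor, and epsilon value are illustrative assumptions:

import tensorflow as tf

def t5_layer_norm(x, weight, epsilon=1e-6):
    # "variance" is the mean of squares over the feature axis (no mean subtraction),
    # so the scaling factor below is 1 / RMS(x)
    variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
    return weight * x * tf.math.rsqrt(variance + epsilon)

# toy usage; the weight of ones mirrors the 'ones' initializer used in TFT5LayerNorm.build
x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(t5_layer_norm(x, tf.ones(3)))
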
@@ -17,16 +17,11 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import os
import sys
import copy
import itertools
from io import open

import numpy as np
import tensorflow as tf

from .configuration_t5 import T5Config
@@ -45,6 +40,28 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
# - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model)
####################################################

class TFT5LayerNorm(tf.keras.layers.Layer):
    def __init__(self, epsilon=1e-6, **kwargs):
""" Construct a layernorm module in the T5 style
|
||||
No bias and no substraction of mean.
|
||||
"""
|
||||
        super(TFT5LayerNorm, self).__init__(**kwargs)
        self.variance_epsilon = epsilon

    def build(self, input_shape):
"""Build shared word embedding layer """
|
||||
        self.weight = self.add_weight(
            "weight",
            shape=(input_shape[-1],),
            initializer='ones')
        super(TFT5LayerNorm, self).build(input_shape)

    def call(self, x):
        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
        return self.weight * x


class TFT5DenseReluDense(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super(TFT5DenseReluDense, self).__init__(**kwargs)
@@ -65,8 +82,8 @@ class TFT5LayerFF(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super(TFT5LayerFF, self).__init__(**kwargs)
        self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
                                                              name='layer_norm')
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
                                        name='layer_norm')
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(self, hidden_states, training=False):
@@ -249,8 +266,8 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
        self.SelfAttention = TFT5Attention(config,
                                           has_relative_attention_bias=has_relative_attention_bias,
                                           name='SelfAttention')
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
                                                              name='layer_norm')
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
                                        name='layer_norm')
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(self, hidden_states, attention_mask=None, position_bias=None,
@@ -273,8 +290,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
        self.EncDecAttention = TFT5Attention(config,
                                             has_relative_attention_bias=has_relative_attention_bias,
                                             name='EncDecAttention')
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
                                                              name='layer_norm')
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
                                        name='layer_norm')
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(self, hidden_states, kv, attention_mask=None, position_bias=None,
@@ -353,8 +370,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                                has_relative_attention_bias=bool(i == 0),
                                name='block_._{}'.format(i))
                      for i in range(config.num_layers)]
        self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
                                                                    name='final_layer_norm')
        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
                                              name='final_layer_norm')
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def _resize_token_embeddings(self, new_num_tokens):