mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
updating TF 2.0 layer_norm to T5 layer norm
parent 8e651f56b7
commit 608a8f5b56
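For context on what this commit changes: tf.keras.layers.LayerNormalization centers the activations (subtracts the mean) and applies a learned bias, while the T5 norm only rescales each vector by its root mean square. A minimal sketch of the two formulas, assuming a toy NumPy input and omitting the learned scale/bias (the function names are illustrative, not part of the commit):

import numpy as np

def standard_layer_norm(x, eps=1e-6):
    # tf.keras-style: subtract the mean, divide by the standard deviation
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def t5_layer_norm(x, eps=1e-6):
    # T5-style: no centering and no bias; divide by the root mean square
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(mean_sq + eps)

x = np.array([[1.0, 2.0, 3.0]])
print(standard_layer_norm(x))  # output row has ~zero mean
print(t5_layer_norm(x))        # output keeps the input's mean direction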
@@ -17,16 +17,11 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-import json
 import logging
 import math
-import os
-import sys
 import copy
 import itertools
-from io import open
 
-import numpy as np
 import tensorflow as tf
 
 from .configuration_t5 import T5Config
@@ -45,6 +40,28 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
 # - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
 ####################################################
 
+class TFT5LayerNorm(tf.keras.layers.Layer):
+    def __init__(self, epsilon=1e-6, **kwargs):
+        """ Construct a layernorm module in the T5 style.
+            No bias and no subtraction of mean.
+        """
+        super(TFT5LayerNorm, self).__init__(**kwargs)
+        self.variance_epsilon = epsilon
+
+    def build(self, input_shape):
+        """Build the per-channel scale weight, initialized to ones"""
+        self.weight = self.add_weight(
+            "weight",
+            shape=(input_shape[-1],),
+            initializer='ones')
+        super(TFT5LayerNorm, self).build(input_shape)
+
+    def call(self, x):
+        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
+        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
+        return self.weight * x
+
+
 class TFT5DenseReluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5DenseReluDense, self).__init__(**kwargs)
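A quick smoke test of the new layer, as a sketch (the tensor values below are illustrative, not part of the commit):

import tensorflow as tf

layer = TFT5LayerNorm(epsilon=1e-6)
x = tf.constant([[1.0, 2.0, 3.0, 4.0]])
y = layer(x)  # the first call runs build(), creating the 'weight' vector of ones

# with weight still all ones, each row is simply divided by its root mean square
rms = tf.sqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + 1e-6)
print(y.numpy())
print((x / rms).numpy())  # same values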
@@ -65,8 +82,8 @@ class TFT5LayerFF(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5LayerFF, self).__init__(**kwargs)
         self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                             name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, training=False):
@@ -249,8 +266,8 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         self.SelfAttention = TFT5Attention(config,
                                            has_relative_attention_bias=has_relative_attention_bias,
                                            name='SelfAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                             name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, attention_mask=None, position_bias=None,
@@ -273,8 +290,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
         self.EncDecAttention = TFT5Attention(config,
                                              has_relative_attention_bias=has_relative_attention_bias,
                                              name='EncDecAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                             name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, kv, attention_mask=None, position_bias=None,
@@ -353,8 +370,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                                 has_relative_attention_bias=bool(i == 0),
                                 name='block_._{}'.format(i))
                      for i in range(config.num_layers)]
-        self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                                   name='final_layer_norm')
+        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                              name='final_layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def _resize_token_embeddings(self, new_num_tokens):
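The remaining hunks are mechanical one-line swaps: TFT5LayerNorm accepts the same epsilon and name keyword arguments as tf.keras.layers.LayerNormalization at every call site, so the sublayer names ('layer_norm', 'final_layer_norm') and hence the variable scopes are unchanged. A small behavioral check of the difference, as a sketch (values illustrative):

import tensorflow as tf

old_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layer_norm')
new_norm = TFT5LayerNorm(epsilon=1e-6, name='layer_norm')

x = tf.constant([[3.0, 3.0, 3.0]])
print(old_norm(x).numpy())  # a constant row is centered to ~zeros
print(new_norm(x).numpy())  # RMS-only scaling maps it to ~ones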