mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Reduce Funnel PT/TF diff (#16744)
* Make Funnel Test less flaky

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
0b8f697219
commit
6bed0647fe
@@ -84,7 +84,7 @@ class TFFunnelEmbeddings(tf.keras.layers.Layer):
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
-self.initializer_range = config.initializer_range
+self.initializer_std = 1.0 if config.initializer_std is None else config.initializer_std
|
||||
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout)
|
||||
@@ -94,7 +94,7 @@ class TFFunnelEmbeddings(tf.keras.layers.Layer):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
-initializer=get_initializer(initializer_range=self.initializer_range),
+initializer=get_initializer(initializer_range=self.initializer_std),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
@@ -65,6 +65,7 @@ class FunnelModelTester:
|
||||
activation_dropout=0.0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=3,
|
||||
+initializer_std=0.02,  # Set to a smaller value, so we can keep the small error threshold (1e-5) in the test
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
@@ -94,6 +95,7 @@ class FunnelModelTester:
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
+self.initializer_std = initializer_std
|
||||
|
||||
# Used in the tests to check the size of the first attention layer
|
||||
self.num_attention_heads = n_head
|
||||
@@ -154,6 +156,7 @@ class FunnelModelTester:
|
||||
activation_dropout=self.activation_dropout,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
+initializer_std=self.initializer_std,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
|
Loading…
Reference in New Issue
Block a user