Fix Flaubert (#9292)

parent 5dd389d1c7
commit d735b074d7
@@ -17,6 +17,7 @@
 """
 
 import itertools
+import random
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
@@ -596,15 +597,15 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         tensor = tensor * mask[..., tf.newaxis]
 
-        # hidden_states and attentions cannot be None in graph mode.
-        hidden_states = ()
-        attentions = ()
+        hidden_states = () if inputs["output_hidden_states"] else None
+        attentions = () if inputs["output_attentions"] else None
 
         # transformer layers
         for i in range(self.n_layers):
             # LayerDrop
-            dropout_probability = tf.random.uniform([1], 0, 1)
+            dropout_probability = random.uniform(0, 1)
 
-            if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
+            if inputs["training"] and (dropout_probability < self.layerdrop):
                 continue
 
             if inputs["output_hidden_states"]:
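A plausible reading of the LayerDrop change above: with `random.uniform(0, 1)` the skip decision is made on a plain Python float, so the `if ... continue` is an ordinary Python branch rather than a comparison on a symbolic tensor produced by `tf.random.uniform`/`tf.less`. A minimal, hypothetical sketch of that pattern outside the model (only the names `n_layers` and `layerdrop` mirror the diff; the values and everything else are illustrative):

import random

n_layers, layerdrop, training = 12, 0.2, True  # illustrative values, not from the model config

kept = []
for i in range(n_layers):
    # Python-level randomness: a plain float, so the branch below is evaluated
    # in Python rather than as a symbolic tensor comparison.
    dropout_probability = random.uniform(0, 1)
    if training and (dropout_probability < layerdrop):
        continue  # drop this layer for the current forward pass
    kept.append(i)

print(f"layers kept this pass: {kept}")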
@@ -642,7 +643,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
             )
             attn = attn_outputs[0]
 
-            if output_attentions:
+            if inputs["output_attentions"]:
                 attentions = attentions + (attn_outputs[1],)
 
             attn = self.dropout(attn, training=inputs["training"])
@@ -676,10 +677,6 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         # move back sequence length to dimension 0
         # tensor = tensor.transpose(0, 1)
 
-        # Set to None here if the output booleans are at False
-        hidden_states = hidden_states if inputs["output_hidden_states"] else None
-        attentions = attentions if inputs["output_attentions"] else None
-
         if not inputs["return_dict"]:
             return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
 
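The lines removed in this last hunk appear redundant once the earlier hunk initializes `hidden_states` and `attentions` to `()` or `None` straight from the output flags; the existing `return tuple(...)` already filters out whatever stayed `None`. A small, hypothetical illustration of that filtering with made-up values:

tensor, hidden_states, attentions = "tensor", None, ("attn_0", "attn_1")
outputs = tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
print(outputs)  # ('tensor', ('attn_0', 'attn_1'))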