Remove dropout in embedding layer of OPT (#18845)
parent: 367026000b
commit: adbf3a40de
@@ -484,8 +484,6 @@ class FlaxOPTDecoder(nn.Module):
 
         hidden_states = inputs_embeds + positions
 
-        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
-
         hidden_state, all_hidden_states, attentions = self.layers(
             hidden_states,
             attention_mask,
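For illustration, a minimal Flax sketch of the embedding path as it stands after this change. TinyDecoderSketch and its sizes are hypothetical, not taken from modeling_flax_opt.py; the point is that the token/position sum now reaches the decoder layers without a dropout_layer call.

# Hypothetical sketch; names and dimensions are illustrative only.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyDecoderSketch(nn.Module):
    vocab_size: int = 32
    max_positions: int = 16
    hidden_size: int = 8

    @nn.compact
    def __call__(self, input_ids):
        inputs_embeds = nn.Embed(self.vocab_size, self.hidden_size)(input_ids)
        positions = nn.Embed(self.max_positions, self.hidden_size)(
            jnp.arange(input_ids.shape[-1])
        )
        # Post-commit behaviour: no dropout between this sum and the layers.
        return inputs_embeds + positions

model = TinyDecoderSketch()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 4), dtype=jnp.int32))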
@@ -637,7 +637,6 @@ class OPTDecoder(OPTPreTrainedModel):
             inputs_embeds = self.project_in(inputs_embeds)
 
         hidden_states = inputs_embeds + pos_embeds
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
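A quick way to see what removing the line implies, sketched in plain PyTorch (the sizes and p=0.1 are assumptions, not OPT's config): without the embedding-stage dropout, this part of the forward pass is deterministic and identical between train and eval mode; dropout inside the decoder layers is untouched by this diff.

# Hypothetical check; embedding sizes and p are illustrative assumptions.
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_tokens = nn.Embedding(32, 8)
embed_positions = nn.Embedding(16, 8)
input_ids = torch.zeros(1, 4, dtype=torch.long)

hidden_states = embed_tokens(input_ids) + embed_positions(torch.arange(4))
# What the removed line did in training mode: zero some entries and rescale
# the survivors by 1 / (1 - p).
old = nn.functional.dropout(hidden_states, p=0.1, training=True)
# Post-commit, hidden_states goes to the decoder layers unmodified.
print(torch.equal(old, hidden_states))  # almost surely False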
@@ -652,7 +652,6 @@ class TFOPTDecoder(tf.keras.layers.Layer):
             inputs_embeds = self.project_in(inputs_embeds)
 
         hidden_states = inputs_embeds + pos_embeds
-        hidden_states = self.dropout(hidden_states, training=training)
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
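The same path in a TF sketch, with hypothetical names and sizes (not from modeling_tf_opt.py): the embedding sum is used as-is, with no self.dropout(..., training=training) call before the decoder layers.

# Hypothetical sketch; names and sizes are illustrative only.
import tensorflow as tf

embed_tokens = tf.keras.layers.Embedding(32, 8)
embed_positions = tf.keras.layers.Embedding(16, 8)
input_ids = tf.zeros((1, 4), dtype=tf.int32)

hidden_states = embed_tokens(input_ids) + embed_positions(tf.range(4))
# Post-commit behaviour: hidden_states flows to the decoder layers directly,
# whether or not training=True.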