Mirror of https://github.com/huggingface/transformers.git
Add final_layer_norm to OPT model (#17785)
* Add final_layer_norm to OPT model
* Add JAX and TF version
* Fix Keras name
* Woops
* Allow for non breaking change
* Apply suggestions from code review
* add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 52404cbad4
commit abc400b06a
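The change is meant to be non-breaking: checkpoints saved or fine-tuned before transformers v4.20.1 were exported without the trailing LayerNorm, and they can keep the old architecture through the new private config flag. A minimal sketch of the intended usage, built from a bare config rather than a hub checkpoint (the model class choice and defaults below are only illustrative):

from transformers import OPTConfig, OPTForCausalLM

# New default: with do_layer_norm_before=True (the setting of the released OPT
# checkpoints), the decoder now ends with decoder.final_layer_norm.
config = OPTConfig(do_layer_norm_before=True)
model = OPTForCausalLM(config)

# Backward compatibility: reproduce the pre-4.20.1 architecture (no final
# LayerNorm) so that older fine-tuned weights still load without missing keys.
legacy_config = OPTConfig(do_layer_norm_before=True, _remove_final_layer_norm=True)
legacy_model = OPTForCausalLM(legacy_config)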
@@ -102,6 +102,7 @@ class OPTConfig(PretrainedConfig):
         ffn_dim=3072,
         max_position_embeddings=2048,
         do_layer_norm_before=True,
+        _remove_final_layer_norm=False,
         word_embed_proj_dim=None,
         dropout=0.1,
         attention_dropout=0.0,
@@ -137,3 +138,8 @@ class OPTConfig(PretrainedConfig):
         self.layerdrop = layerdrop
         self.use_cache = use_cache
         self.do_layer_norm_before = do_layer_norm_before
+
+        # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        self._remove_final_layer_norm = _remove_final_layer_norm
@@ -37,8 +37,6 @@ def load_checkpoint(checkpoint_path):
     # pop unnecessary weights
     keys_to_delete = [
         "decoder.version",
-        "decoder.layer_norm.weight",
-        "decoder.layer_norm.bias",
         "decoder.output_projection.weight",
     ]
     for key in keys_to_delete:
@@ -48,6 +46,8 @@ def load_checkpoint(checkpoint_path):
     keys_to_rename = {
         "decoder.project_in_dim.weight": "decoder.project_in.weight",
         "decoder.project_out_dim.weight": "decoder.project_out.weight",
+        "decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
+        "decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
     }
     for old_key, new_key in keys_to_rename.items():
         if old_key in sd:
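The conversion-script change mirrors the modeling change: the metaseq final LayerNorm weights are no longer dropped but renamed to the new decoder.final_layer_norm parameters. A self-contained toy version of the delete/rename pattern, with made-up dictionary values and only the key names taken from the hunks above:

# Toy state dict standing in for the loaded metaseq checkpoint (values are fake).
sd = {
    "decoder.version": 1,
    "decoder.output_projection.weight": "tied output projection",
    "decoder.layer_norm.weight": "final LayerNorm scale",
    "decoder.layer_norm.bias": "final LayerNorm bias",
}

keys_to_delete = ["decoder.version", "decoder.output_projection.weight"]
for key in keys_to_delete:
    if key in sd:
        sd.pop(key)

keys_to_rename = {
    "decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
    "decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
}
for old_key, new_key in keys_to_rename.items():
    if old_key in sd:
        sd[new_key] = sd.pop(old_key)

print(sorted(sd))  # ['decoder.final_layer_norm.bias', 'decoder.final_layer_norm.weight']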
@@ -452,6 +452,14 @@ class FlaxOPTDecoder(nn.Module):
             self.project_in = None
             self.project_out = None
 
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+        else:
+            self.final_layer_norm = None
+
         self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
 
     def __call__(
@@ -487,6 +495,9 @@ class FlaxOPTDecoder(nn.Module):
             output_hidden_states=output_hidden_states,
         )
 
+        if self.final_layer_norm is not None:
+            hidden_state = self.final_layer_norm(hidden_state)
+
         if self.project_out is not None:
             hidden_state = self.project_out(hidden_state)
 
@@ -492,7 +492,14 @@ class OPTDecoder(OPTPreTrainedModel):
         else:
             self.project_in = None
 
-        self.layer_norm = None
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+        else:
+            self.final_layer_norm = None
+
         self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
 
         self.gradient_checkpointing = False
@@ -688,6 +695,9 @@ class OPTDecoder(OPTPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
 
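At the end of OPTDecoder.forward, the new LayerNorm is applied after the last decoder layer and before project_out, which maps hidden states back to word_embed_proj_dim when that differs from hidden_size. A standalone PyTorch sketch of that ordering, with illustrative dimensions not taken from any real checkpoint:

import torch
from torch import nn

hidden_size, word_embed_proj_dim = 768, 512  # illustrative sizes only
final_layer_norm = nn.LayerNorm(hidden_size)  # None when _remove_final_layer_norm is set
project_out = nn.Linear(hidden_size, word_embed_proj_dim, bias=False)  # None when dims match

hidden_states = torch.randn(1, 4, hidden_size)  # stands in for the last decoder layer output
if final_layer_norm is not None:
    hidden_states = final_layer_norm(hidden_states)  # new: normalize first
if project_out is not None:
    hidden_states = project_out(hidden_states)  # then project down to word_embed_proj_dim
print(hidden_states.shape)  # torch.Size([1, 4, 512])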
@@ -506,6 +506,14 @@ class TFOPTDecoder(tf.keras.layers.Layer):
             name="embed_positions",
         )
 
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+        else:
+            self.final_layer_norm = None
+
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
             self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
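The explicit name="final_layer_norm" matters in the TF port (the "Fix Keras name" step in the commit message): the Keras layer name determines the TF variable names, which is what weight-name matching relies on when loading converted checkpoints. A small sketch of the resulting variable names, using the opt-125m hidden size purely for illustration:

import tensorflow as tf

ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
ln.build((None, None, 768))  # 768 = opt-125m hidden_size, illustrative only
print([w.name for w in ln.weights])  # ['final_layer_norm/gamma:0', 'final_layer_norm/beta:0']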
@@ -681,6 +689,9 @@ class TFOPTDecoder(tf.keras.layers.Layer):
             if output_attentions:
                 all_self_attns += (layer_self_attn,)
 
+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
 
@@ -292,10 +292,10 @@ class FlaxOPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"
 
         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
         ]
 
         predicted_outputs = []
@@ -344,10 +344,10 @@ class OPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"
 
         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
        ]
 
         predicted_outputs = []
@@ -330,10 +330,10 @@ class TFOPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"
 
         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
         ]
 
         predicted_outputs = []
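The updated expectations are the greedy continuations produced once the final LayerNorm is applied; the old strings ("Canaver Canaver", "Parisdylib") were artifacts of running the released weights without it. A rough sketch of how one such continuation could be reproduced locally; the prompt and max_length are inferred for illustration and are not taken from the test files:

from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Today is a beautiful day and I want", return_tensors="pt")
output_ids = model.generate(**inputs, do_sample=False, max_length=10)  # greedy decoding
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))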