Add final_layer_norm to OPT model (#17785)

* Add final_layer_norm to OPT model

* Add JAX and TF version

* Fix Keras name

* Woops

* Allow for non-breaking change

* Apply suggestions from code review

* Add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Authored by Thomas Wang on 2022-06-21 20:26:36 +02:00; committed by GitHub
parent 52404cbad4
commit abc400b06a
8 changed files with 53 additions and 15 deletions


@@ -102,6 +102,7 @@ class OPTConfig(PretrainedConfig):
         ffn_dim=3072,
         max_position_embeddings=2048,
         do_layer_norm_before=True,
+        _remove_final_layer_norm=False,
         word_embed_proj_dim=None,
         dropout=0.1,
         attention_dropout=0.0,
@@ -137,3 +138,8 @@ class OPTConfig(PretrainedConfig):
         self.layerdrop = layerdrop
         self.use_cache = use_cache
         self.do_layer_norm_before = do_layer_norm_before
+        # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        self._remove_final_layer_norm = _remove_final_layer_norm
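
For anyone holding a checkpoint that was fine-tuned before v4.20.1, the flag can be set at load time. A minimal sketch, assuming a hypothetical model id `my-org/opt-125m-finetuned` (the id is illustrative, not from this PR):

from transformers import OPTConfig, OPTForCausalLM

# Checkpoints fine-tuned with transformers < v4.20.1 were trained without the
# final layer norm being applied, so their configs should opt out of it.
# "my-org/opt-125m-finetuned" is a hypothetical model id for illustration.
config = OPTConfig.from_pretrained("my-org/opt-125m-finetuned", _remove_final_layer_norm=True)
model = OPTForCausalLM.from_pretrained("my-org/opt-125m-finetuned", config=config)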


@@ -37,8 +37,6 @@ def load_checkpoint(checkpoint_path):
     # pop unnecessary weights
     keys_to_delete = [
         "decoder.version",
-        "decoder.layer_norm.weight",
-        "decoder.layer_norm.bias",
         "decoder.output_projection.weight",
     ]
     for key in keys_to_delete:
@@ -48,6 +46,8 @@ def load_checkpoint(checkpoint_path):
     keys_to_rename = {
         "decoder.project_in_dim.weight": "decoder.project_in.weight",
         "decoder.project_out_dim.weight": "decoder.project_out.weight",
+        "decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
+        "decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
     }
     for old_key, new_key in keys_to_rename.items():
         if old_key in sd:
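
The effect of moving the two `decoder.layer_norm.*` keys from the delete list into the rename map is that the metaseq weights are now carried over instead of dropped. A self-contained sketch of the remapping on a toy state dict:

# Toy stand-in for the fairseq/metaseq state dict loaded by load_checkpoint.
sd = {"decoder.layer_norm.weight": "W", "decoder.layer_norm.bias": "b"}

keys_to_rename = {
    "decoder.project_in_dim.weight": "decoder.project_in.weight",
    "decoder.project_out_dim.weight": "decoder.project_out.weight",
    "decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
    "decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
}
for old_key, new_key in keys_to_rename.items():
    if old_key in sd:
        sd[new_key] = sd.pop(old_key)

print(sd)  # {'decoder.final_layer_norm.weight': 'W', 'decoder.final_layer_norm.bias': 'b'}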


@@ -452,6 +452,14 @@ class FlaxOPTDecoder(nn.Module):
             self.project_in = None
             self.project_out = None

+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+        else:
+            self.final_layer_norm = None
+
         self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)

     def __call__(
@@ -487,6 +495,9 @@ class FlaxOPTDecoder(nn.Module):
             output_hidden_states=output_hidden_states,
         )

+        if self.final_layer_norm is not None:
+            hidden_state = self.final_layer_norm(hidden_state)
+
         if self.project_out is not None:
             hidden_state = self.project_out(hidden_state)


@@ -492,7 +492,14 @@ class OPTDecoder(OPTPreTrainedModel):
         else:
             self.project_in = None

-        self.layer_norm = None
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+        else:
+            self.final_layer_norm = None

         self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False
@@ -688,6 +695,9 @@ class OPTDecoder(OPTPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
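
The PyTorch, Flax, and TF decoders all follow the same pattern: build the extra LayerNorm only when `do_layer_norm_before` is set and the backward-compatibility flag is not, then apply it after the decoder layers and before `project_out`. A minimal standalone sketch of that pattern (TinyDecoder is illustrative, not the actual OPTDecoder):

import torch
from torch import nn

class TinyDecoder(nn.Module):
    def __init__(self, hidden_size, do_layer_norm_before=True, _remove_final_layer_norm=False):
        super().__init__()
        # Mirror of the conditional above: no final norm for pre-v4.20.1 fine-tunes.
        if do_layer_norm_before and not _remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(hidden_size)
        else:
            self.final_layer_norm = None

    def forward(self, hidden_states):
        # ... the decoder layers would run here ...
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states

out = TinyDecoder(8)(torch.zeros(1, 4, 8))  # TinyDecoder(8, _remove_final_layer_norm=True) skips the norm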


@@ -506,6 +506,14 @@ class TFOPTDecoder(tf.keras.layers.Layer):
             name="embed_positions",
         )

+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+        else:
+            self.final_layer_norm = None
+
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
             self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
@@ -681,6 +689,9 @@ class TFOPTDecoder(tf.keras.layers.Layer):
                 if output_attentions:
                     all_self_attns += (layer_self_attn,)

+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
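
Passing `name="final_layer_norm"` to the Keras layer (the "Fix Keras name" commit above) matters because Keras embeds sublayer names in variable paths, which is what cross-framework weight loading matches on. A small sketch showing the resulting variable names (TinyTFDecoder is illustrative only):

import tensorflow as tf

class TinyTFDecoder(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")

    def call(self, hidden_states):
        return self.final_layer_norm(hidden_states)

layer = TinyTFDecoder(name="decoder")
_ = layer(tf.zeros((1, 4, 8)))  # build the weights
print([w.name for w in layer.weights])
# e.g. ['decoder/final_layer_norm/gamma:0', 'decoder/final_layer_norm/beta:0']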


@@ -292,10 +292,10 @@ class FlaxOPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"

         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
         ]

         predicted_outputs = []


@@ -344,10 +344,10 @@ class OPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"

         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
         ]

         predicted_outputs = []


@@ -330,10 +330,10 @@ class TFOPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"

         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
        ]

         predicted_outputs = []
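
The old expected strings ("Canaver Canaver", "Parisdylib") were degenerate because the final layer norm was being skipped; with it restored, greedy decoding yields the new strings in all three frameworks. A sketch of how such outputs can be reproduced, mirroring the test setup (the prompt strings here are inferred from the expected outputs, not copied from the test file):

from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

prompts = [
    "Today is a beautiful day and",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]
for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Greedy decoding with a short max_length, matching the ~10-token expected strings.
    output_ids = model.generate(input_ids, max_length=10, do_sample=False)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])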