mirror of https://github.com/huggingface/transformers.git
Generalize decay_mask_fn to apply mask to all LayerNorm params (#18273)
* generalize decay_mask_fn to find all layernorm params
* fixup
* generalising decay_mask_fn
parent 83d2d74509
commit 170fcaa604
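Each hunk below makes the same replacement in a different Flax example script: a hard-coded list of one model family's layer norm parameter names (FlaxBart's self_attn_layer_norm, FlaxGPT2's ln_1/ln_2/ln_f, FlaxBERT's LayerNorm) gives way to a generic scan of the flattened parameter paths. As a rough illustration of that scan, here is a minimal, hypothetical sketch using made-up parameter paths that mimic the three naming schemes; in the real scripts the paths come from flax.traverse_util.flatten_dict(params), and only the key tuples matter here, so the values are placeholders.

# Hypothetical flattened parameter paths mimicking Bart-, GPT-2- and BERT-style naming.
flat_params = {
    ("model", "encoder", "layers_0", "self_attn_layer_norm", "scale"): 0.0,  # Bart-style
    ("transformer", "h", "0", "ln_1", "scale"): 0.0,                         # GPT-2-style
    ("bert", "encoder", "layer", "0", "LayerNorm", "scale"): 0.0,            # BERT-style
    ("model", "encoder", "layers_0", "fc1", "kernel"): 0.0,                  # regular weight
    ("model", "encoder", "layers_0", "fc1", "bias"): 0.0,                    # regular bias
}

# Detection logic from the diff: a parameter belongs to a layer norm if any
# candidate substring appears in its joined, lower-cased path.
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
    [
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    ]
)

# Decay everything except biases and the detected layer norm scales.
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
# -> fc1/kernel maps to True (decayed); the fc1 bias and the three layer norm scales map to False (not decayed).

Note that the test is a plain substring match on the joined, lower-cased path, so the "ln" candidate also picks up names such as ln_f, and it would likewise match any other module whose path happens to contain one of the candidate substrings.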
@@ -875,15 +875,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBart.
-    # For FlaxT5, one should correct the layer norm parameter naming
-    # accordingly - see `run_t5_mlm_flax.py` e.g.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        layer_norm_params = [
-            (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
-        ]
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
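As a sanity check on the hunk above, the generalized lookup recovers the same (name, "scale") pairs that the removed FlaxBart-specific list spelled out by hand. The sketch below uses invented Bart-like paths purely for illustration; the real keys come from the model's parameter tree.

flat_params = {
    ("model", "encoder", "layers_0", "self_attn_layer_norm", "scale"): 0.0,
    ("model", "encoder", "layernorm_embedding", "scale"): 0.0,
    ("model", "decoder", "layers_0", "final_layer_norm", "scale"): 0.0,
    ("model", "encoder", "layers_0", "fc1", "kernel"): 0.0,
}

# Old, hand-written entries (removed by this commit):
layer_norm_params = [
    (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
]

# New, generalized lookup (added by this commit):
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
    [
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    ]
)

# For these paths, both approaches exclude exactly the same scale parameters.
assert layer_norm_named_params == set(layer_norm_params)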
@@ -638,15 +638,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxGPT2.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {
-            path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
-            for path in flat_params
-        }
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -658,12 +658,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBERT-like models.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -326,7 +326,6 @@ class FlaxDataCollatorForT5MLM:
     decoder_start_token_id: int

     def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
-
         # convert list to dict and tensorize input
         batch = BatchEncoding(
             {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
@@ -395,7 +394,6 @@ class FlaxDataCollatorForT5MLM:
         return input_ids

     def random_spans_noise_mask(self, length):
-
         """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

         Noise mask consisting of random spans of noise tokens.
@@ -782,10 +780,17 @@ def main():
     # The mask is True for parameters that should be decayed.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {
-            path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
-            for path in flat_params
-        }
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -327,12 +327,19 @@ def create_train_state(
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBERT-like models.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     tx = optax.adamw(
@@ -723,15 +723,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBart.
-    # For FlaxT5, one should correct the layer norm parameter naming
-    # accordingly - see `run_t5_mlm_flax.py` e.g.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        layer_norm_params = [
-            (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
-        ]
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -226,7 +226,17 @@ def create_train_state(
     # The mask is True for parameters that should be decayed.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     tx = optax.adamw(
@@ -284,12 +284,19 @@ def create_train_state(
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBERT-like models.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     tx = optax.adamw(
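The hunks ending at tx = optax.adamw( show where the mask function is consumed: decay_mask_fn is passed to Optax's AdamW as its mask argument, which accepts either a boolean pytree or a callable producing one, so weight decay is applied only to parameters whose mask entry is True. A sketch of that wiring follows; the hyperparameter values are placeholders, since the actual scripts read them from their training arguments.

import optax
from flax import traverse_util

# decay_mask_fn with the body introduced by this commit (same as in the hunks above).
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = set(
        [
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        ]
    )
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

# Placeholder hyperparameters standing in for the values the scripts take from their training arguments.
tx = optax.adamw(
    learning_rate=5e-5,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=0.01,
    mask=decay_mask_fn,  # Optax applies this to the parameter pytree to decide which leaves get weight decay
)

Because the same detection works for Bart-, GPT-2-, T5- and BERT-style parameter names, the function no longer needs to be adjusted per example script.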