Fix importing unofficial TF models with extra optimizer weights
commit 73368963b2
parent d7dabfeff5
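Context for the fix: TF checkpoints trained with the original BERT repository's AdamWeightDecayOptimizer (or with LAMB) store optimizer slot variables such as adam_m and adam_v alongside the model weights, and the loaders patched below must skip them. A minimal sketch of how to see those extra entries, assuming TensorFlow is installed; the checkpoint path is a hypothetical example:

    # Sketch: list every variable stored in a TF checkpoint. Unofficially
    # trained checkpoints often include optimizer state (e.g. ".../adam_m",
    # ".../adam_v") next to the actual model weights.
    import tensorflow as tf

    for name, shape in tf.train.list_variables("/path/to/bert_model.ckpt"):
        print(name, shape)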
@@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # Ignore the gradients applied by the LAMB/ADAM optimizers.
-        if "adam_m" in name or "adam_v" in name or "global_step" in name:
+        if (
+            "adam_m" in name
+            or "adam_v" in name
+            or "AdamWeightDecayOptimizer" in name
+            or "AdamWeightDecayOptimizer_1" in name
+            or "global_step" in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
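Note on the check above: because name was split on "/", each `in` test matches a whole path segment rather than an arbitrary substring. A small sketch, using a hypothetical variable name:

    # The variable name below is a hypothetical example of optimizer state.
    name = "bert/embeddings/word_embeddings/adam_m".split("/")
    skip = (
        "adam_m" in name
        or "adam_v" in name
        or "AdamWeightDecayOptimizer" in name
        or "AdamWeightDecayOptimizer_1" in name
        or "global_step" in name
    )
    print(skip)  # True: this is optimizer state, not a model weight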
@@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
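The any(...) form used here is equivalent to the chained `or` checks in the ALBERT loader; it just folds the skip list into one place. A quick sketch with hypothetical variable names:

    skip_list = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]

    for raw in ["bert/pooler/dense/kernel", "bert/pooler/dense/kernel/adam_v"]:
        name = raw.split("/")
        print(raw, "->", "skip" if any(n in skip_list for n in name) else "load")
    # bert/pooler/dense/kernel -> load
    # bert/pooler/dense/kernel/adam_v -> skip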
@@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
         name = txt_name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             tf_weights.pop(txt_name, None)
             continue
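Unlike the other loaders, the T5 loader keeps the checkpoint variables in a tf_weights dict and reports whatever remains at the end, so skipped optimizer entries are also popped to keep that report accurate. A toy sketch of the bookkeeping, with hypothetical names and placeholder values:

    tf_weights = {
        "shared/embedding": "...",          # real weight (placeholder value)
        "shared/embedding/adam_v": "...",   # optimizer state (placeholder value)
    }
    txt_name = "shared/embedding/adam_v"
    tf_weights.pop(txt_name, None)  # drop it so it is not reported as "not copied"
    print(list(tf_weights))  # ['shared/embedding']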
@@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
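End-to-end, these loaders are reached when converting a TF checkpoint to PyTorch. A sketch exercising the BERT loader directly, matching the load_tf_weights_in_bert(model, config, tf_checkpoint_path) signature in the hunk above; the paths are hypothetical examples:

    from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

    config = BertConfig.from_json_file("/path/to/bert_config.json")
    model = BertForPreTraining(config)
    # With this fix, checkpoints carrying AdamWeightDecayOptimizer slot
    # variables load cleanly instead of tripping over the extra entries.
    load_tf_weights_in_bert(model, config, "/path/to/bert_model.ckpt")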