mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix ALBERT exports with pretraining + sp classifier; Fix naming for ALBERT TF models
This commit is contained in:
parent
b3d834ae11
commit
e85855f2c4
@ -99,8 +99,23 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
||||
# Naming was changed to be more explicit
|
||||
name = name.replace("embeddings/attention", "embeddings")
|
||||
name = name.replace("inner_group_", "albert_layers/")
|
||||
name = name.replace("group_", "albert_layer_groups/")
|
||||
name = name.replace("group_", "albert_layer_groups/")
|
||||
|
||||
# Classifier
|
||||
if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
|
||||
name = "classifier/" + name
|
||||
|
||||
# No ALBERT model currently handles the next sentence prediction task
|
||||
if "seq_relationship" in name:
|
||||
continue
|
||||
|
||||
name = name.split('/')
|
||||
|
||||
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
||||
if "adam_m" in name or "adam_v" in name or "global_step" in name:
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
|
@ -31,10 +31,10 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-tf_model.h5",
|
||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-tf_model.h5",
|
||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-tf_model.h5",
|
||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-tf_model.h5",
|
||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-tf_model.h5",
|
||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-tf_model.h5",
|
||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-tf_model.h5",
|
||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-tf_model.h5",
|
||||
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5",
|
||||
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5",
|
||||
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5",
|
||||
|
Loading…
Reference in New Issue
Block a user