diff --git a/transformers/modeling_albert.py b/transformers/modeling_albert.py index 49d120ffae6..0f67bf8f360 100644 --- a/transformers/modeling_albert.py +++ b/transformers/modeling_albert.py @@ -99,8 +99,23 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path): # Naming was changed to be more explicit name = name.replace("embeddings/attention", "embeddings") name = name.replace("inner_group_", "albert_layers/") - name = name.replace("group_", "albert_layer_groups/") + name = name.replace("group_", "albert_layer_groups/") + + # Classifier + if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name): + name = "classifier/" + name + + # No ALBERT model currently handles the next sentence prediction task + if "seq_relationship" in name: + continue + name = name.split('/') + + # Ignore the gradients applied by the LAMB/ADAM optimizers. + if "adam_m" in name or "adam_v" in name or "global_step" in name: + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model for m_name in name: if re.fullmatch(r'[A-Za-z]+_\d+', m_name): diff --git a/transformers/modeling_tf_albert.py b/transformers/modeling_tf_albert.py index 164dc743209..d1650d41a83 100644 --- a/transformers/modeling_tf_albert.py +++ b/transformers/modeling_tf_albert.py @@ -31,10 +31,10 @@ import logging logger = logging.getLogger(__name__) TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { - 'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-tf_model.h5", - 'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-tf_model.h5", - 'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-tf_model.h5", - 'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-tf_model.h5", + 'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-tf_model.h5", + 'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-tf_model.h5", + 'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-tf_model.h5", + 'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-tf_model.h5", 'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5", 'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5", 'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5",