Warning for ALBERT-v2 models

This commit is contained in:
LysandreJik 2019-11-26 10:28:41 -05:00 committed by Lysandre Debut
parent c9cb7f8a0f
commit f873b55e43

View File

@ -315,6 +315,10 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
"https://github.com/google-research/google-research/issues/119 for more information.")
config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)