mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-20 13:08:21 +06:00
clean up a little bit PT <=> TF conversion
This commit is contained in:
parent
bebaa14039
commit
f8fb4335c9
@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
||||
tf_inputs = tf.constant(inputs_list)
|
||||
tfo = tf_model(tf_inputs, training=False) # build the network
|
||||
|
||||
pt_model = pt_model_class.from_pretrained(None,
|
||||
config=config,
|
||||
state_dict=torch.load(pytorch_checkpoint_path,
|
||||
map_location='cpu'))
|
||||
pt_model = pt_model_class(config)
|
||||
pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'),
|
||||
strict-False)
|
||||
pt_model.eval()
|
||||
|
||||
pt_inputs = torch.tensor(inputs_list)
|
||||
with torch.no_grad():
|
||||
pto = pt_model(pt_inputs)
|
||||
|
@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
|
||||
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
|
||||
if pretrained_model_name_or_path is not None and (
|
||||
"albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
|
||||
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
|
||||
"https://github.com/google-research/google-research/issues/119 for more information.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user