fix loading from tf/pt

This commit is contained in:
thomwolf 2019-09-25 17:46:16 +02:00
parent a049c8043b
commit 3b7fb48c3b
3 changed files with 5 additions and 4 deletions

View File

@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/')
pt_model = BertForSequenceClassification.from_pretrained('./runs/')
# Quickly inspect a few predictions
inputs = tokenizer.encode_plus("I said the company is doing great", "The company has good results", add_special_tokens=True)
pred = pt_model(torch.tensor([tokens]))
# Divers
import torch

View File

@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError("Error no file named {} found in directory {}".format(
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME),
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path

View File

@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError("Error no file named {} found in directory {}".format(
raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):