mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
fix loading from tf/pt
This commit is contained in:
parent
a049c8043b
commit
3b7fb48c3b
@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/')
|
||||
pt_model = BertForSequenceClassification.from_pretrained('./runs/')
|
||||
|
||||
# Quickly inspect a few predictions
|
||||
|
||||
inputs = tokenizer.encode_plus("I said the company is doing great", "The company has good results", add_special_tokens=True)
|
||||
pred = pt_model(torch.tensor([tokens]))
|
||||
|
||||
# Divers
|
||||
import torch
|
||||
|
@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError("Error no file named {} found in directory {}".format(
|
||||
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME),
|
||||
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
|
||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
|
||||
pretrained_model_name_or_path))
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
|
@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError("Error no file named {} found in directory {}".format(
|
||||
raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
|
||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
|
||||
pretrained_model_name_or_path))
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
|
Loading…
Reference in New Issue
Block a user