mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
fix loading from tf/pt
This commit is contained in:
parent
a049c8043b
commit
3b7fb48c3b
@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/')
|
|||||||
pt_model = BertForSequenceClassification.from_pretrained('./runs/')
|
pt_model = BertForSequenceClassification.from_pretrained('./runs/')
|
||||||
|
|
||||||
# Quickly inspect a few predictions
|
# Quickly inspect a few predictions
|
||||||
|
inputs = tokenizer.encode_plus("I said the company is doing great", "The company has good results", add_special_tokens=True)
|
||||||
|
pred = pt_model(torch.tensor([tokens]))
|
||||||
|
|
||||||
# Divers
|
# Divers
|
||||||
import torch
|
import torch
|
||||||
|
@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError("Error no file named {} found in directory {}".format(
|
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
|
||||||
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME),
|
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
|
||||||
pretrained_model_name_or_path))
|
pretrained_model_name_or_path))
|
||||||
elif os.path.isfile(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
|
@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError("Error no file named {} found in directory {}".format(
|
raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
|
||||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
|
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
|
||||||
pretrained_model_name_or_path))
|
pretrained_model_name_or_path))
|
||||||
elif os.path.isfile(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path):
|
||||||
|
Loading…
Reference in New Issue
Block a user