model_path should be ignored as the checkpoint path (#11157)

* model_path is refered as the path of the trainer, and should be ignored as the checkpoint path.

* Improved according to Sgugger's comment.
This commit is contained in:
Masatoshi TSUCHIYA 2021-04-12 22:06:41 +09:00 committed by GitHub
parent 623cd6aef9
commit ef102c4886
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -332,13 +332,15 @@ def main():
# Training
if training_args.do_train:
checkpoint = None
if last_checkpoint is not None:
model_path = last_checkpoint
checkpoint = last_checkpoint
elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
# Check the config from that potential checkpoint has the right number of labels before using it as a
# checkpoint.
if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels:
checkpoint = model_args.model_name_or_path
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)