model_path should be ignored as the checkpoint path (#11157)

* model_path refers to the path the model is loaded from and should not be treated as the checkpoint path; the training script now resolves the resume checkpoint separately and passes it to trainer.train as resume_from_checkpoint (see the sketch below).

* Improved according to sgugger's review comment.
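
For illustration, the sketch below shows the pattern this change moves to: resolve the checkpoint to resume from explicitly, then hand it to trainer.train via resume_from_checkpoint. The helper resolve_resume_checkpoint is hypothetical (the example script inlines this logic), and trainer, training_args, model_args, and num_labels are assumed to be constructed as in the script.

import os

from transformers import AutoConfig
from transformers.trainer_utils import get_last_checkpoint


def resolve_resume_checkpoint(output_dir, model_name_or_path, num_labels):
    """Hypothetical helper mirroring the new checkpoint-resolution logic (illustration only)."""
    # Prefer a checkpoint left in the output directory by a previous run.
    checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    if checkpoint is None and os.path.isdir(model_name_or_path):
        # Only treat a local model directory as a resumable checkpoint when its
        # config carries the label count this task expects.
        if AutoConfig.from_pretrained(model_name_or_path).num_labels == num_labels:
            checkpoint = model_name_or_path
    return checkpoint


# Usage, assuming `trainer`, `training_args`, `model_args`, and `num_labels`
# already exist as in the example script:
#
#     checkpoint = resolve_resume_checkpoint(
#         training_args.output_dir, model_args.model_name_or_path, num_labels
#     )
#     trainer.train(resume_from_checkpoint=checkpoint)
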
Masatoshi TSUCHIYA, 2021-04-12 22:06:41 +09:00, committed by GitHub
parent 623cd6aef9
commit ef102c4886

@@ -332,13 +332,15 @@ def main():
     # Training
     if training_args.do_train:
+        checkpoint = None
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
-        else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            # Check the config from that potential checkpoint has the right number of labels before using it as a
+            # checkpoint.
+            if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels:
+                checkpoint = model_args.model_name_or_path
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         metrics = train_result.metrics
         max_train_samples = (
             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
         )