mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
make style
This commit is contained in:
parent
56d4ba8ddb
commit
801aaa5508
@ -228,8 +228,10 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if args.model_name_or_path and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
if (
|
||||
args.model_name_or_path
|
||||
and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
|
||||
and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
@ -587,9 +589,7 @@ def main():
|
||||
if args.should_continue:
|
||||
sorted_checkpoints = _sorted_checkpoints(args)
|
||||
if len(sorted_checkpoints) == 0:
|
||||
raise ValueError(
|
||||
"Used --should_continue but no checkpoint was found in --output_dir."
|
||||
)
|
||||
raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
|
||||
else:
|
||||
args.model_name_or_path = sorted_checkpoints[-1]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user