Proper map location for optimizer load (#22273)

* Proper map location for optimizer load

* What happened to my code?
Sylvain Gugger 2023-03-20 11:30:46 -04:00 committed by GitHub
parent 786092a35e
commit da005253b8

@@ -2433,8 +2433,12 @@ class Trainer:
                     self.model_wrapped.register_post_step_hook(opt_load_hook)
                 else:
+                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
+                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
+                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state).
+                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
                     self.optimizer.load_state_dict(
-                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                     )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
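
As a rough, standalone sketch of the behavior this change introduces (not the Trainer API itself): the function and argument names below (load_optimizer_state, checkpoint_dir, device, world_size) are hypothetical stand-ins for self.optimizer, checkpoint, self.args.device, and self.args.world_size, and OPTIMIZER_NAME is assumed to name the serialized optimizer file.

import os

import torch

OPTIMIZER_NAME = "optimizer.pt"  # hypothetical stand-in for the Trainer's OPTIMIZER_NAME constant


def load_optimizer_state(optimizer, checkpoint_dir, device, world_size):
    # Single-GPU training: deserialize on CPU so the optimizer state does not
    # transiently occupy GPU RAM twice while load_state_dict copies it over.
    # Distributed training: every rank loads its own copy, so materializing
    # world_size copies on the (shared) CPU is the bigger OOM risk; load
    # straight onto each rank's device instead.
    map_location = device if world_size > 1 else "cpu"
    state_dict = torch.load(os.path.join(checkpoint_dir, OPTIMIZER_NAME), map_location=map_location)
    optimizer.load_state_dict(state_dict)

For example, a single-process run would call this with world_size=1 and load on CPU, while a 4-GPU run would pass world_size=4 and each rank's own torch.device so the state lands directly on that rank's GPU.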