Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Proper map location for optimizer load (#22273)
* Proper map location for optimizer load
* What happened to my code?
parent 786092a35e
commit da005253b8
@@ -2433,8 +2433,12 @@ class Trainer:
                     self.model_wrapped.register_post_step_hook(opt_load_hook)
                 else:
+                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
+                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
+                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state
+                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
                     self.optimizer.load_state_dict(
-                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                     )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
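In short, the hunk stops hard-coding map_location="cpu" when restoring the optimizer state: a single-process run still deserializes onto the CPU to avoid a GPU-RAM spike, while a distributed run (world_size > 1) loads directly onto each rank's device so that N processes do not each hold a full CPU copy of the state. Below is a minimal standalone sketch of the same idea outside the Trainer; the load_optimizer_state helper and the "optimizer.pt" filename are assumptions for illustration, not the library's API.

    import os

    import torch


    def load_optimizer_state(optimizer, checkpoint_dir, device, world_size):
        # Hypothetical helper sketching the map_location choice from the commit above.
        # Single-process run: deserialize onto the CPU first; optimizer.load_state_dict
        # then casts/moves state tensors to match each parameter's device.
        # Multi-process run: load straight onto this rank's device, since every rank
        # would otherwise materialize a full CPU copy of the optimizer state.
        map_location = device if world_size > 1 else "cpu"
        state = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=map_location)
        optimizer.load_state_dict(state)

With that split, resuming a large model on one GPU keeps the transient copy of the optimizer state in host RAM, whereas a multi-GPU run avoids holding num_gpu simultaneous CPU copies.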