Fix galore lr display with schedulers (#31710)

* fix galore lr display with lr schedulers

* style

* add some tests to check for displayed lrs

* copy-paste err for warmup steps

* standardize the default lr to be only in the optimizer

* trying out my luck with the reads
Author: Anton Vlasjuk, 2024-07-05 19:59:09 +02:00 (committed by GitHub)
parent ac26260436
commit a01b033cb4
3 changed files with 93 additions and 3 deletions
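
For context: with layer-wise GaLore the Trainer never steps a real optimizer itself; one optimizer and one scheduler are attached per parameter and stepped from post-accumulate-grad hooks, while the Trainer only holds dummy stand-ins. The sketch below is not part of this commit, just a minimal illustration of that pattern (the model, lr, and scheduler choice are made up; it assumes PyTorch >= 2.1 for the hook API):

# Illustration only, not from this diff: the layer-wise pattern behind the dummy
# optimizer/scheduler whose lr display this commit fixes.
import torch

tiny_model = torch.nn.Linear(4, 4)  # stand-in for the real model

optimizer_dict, scheduler_dict = {}, {}
for param in tiny_model.parameters():
    opt = torch.optim.AdamW([param], lr=2e-4)                       # one optimizer per parameter
    optimizer_dict[param] = opt
    scheduler_dict[param] = torch.optim.lr_scheduler.LinearLR(opt)  # one scheduler per parameter

def optimizer_hook(param):
    # the actual optimization and scheduling steps happen here, right after grad accumulation
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()
    scheduler_dict[param].step()

for param in tiny_model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)  # PyTorch >= 2.1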

src/transformers/optimization.py

@@ -519,7 +519,7 @@ def get_scheduler(
             if param.requires_grad:
                 param.register_post_accumulate_grad_hook(scheduler_hook)

-        return LayerWiseDummyScheduler()
+        return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])

     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
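
A note on the new arguments: `optimizer` here is the layer-wise dummy optimizer, and like any `torch.optim.Optimizer` subclass it keeps its constructor hyperparameters in `.defaults`, so `optimizer.defaults["lr"]` is the lr configured on the Trainer. A tiny sketch of that generic behavior (plain SGD used purely for illustration, not taken from the diff):

# Illustration only: every torch.optim.Optimizer stores its hyperparameter defaults
# in `.defaults`, which is where the default lr handed to the dummy scheduler comes from.
import torch

opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=1e-9)
print(opt.defaults["lr"])  # 1e-09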

src/transformers/trainer_pt_utils.py

@@ -27,6 +27,7 @@ import warnings
 from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
+from itertools import chain
 from logging import StreamHandler
 from typing import Any, Dict, Iterator, List, Optional, Union
@@ -1379,13 +1380,24 @@ class LayerWiseDummyScheduler(LRScheduler):
     """

     def __init__(self, *args, **kwargs):
-        optimizer = LayerWiseDummyOptimizer()
+        self.default_lr = kwargs["lr"]
+        optimizer = LayerWiseDummyOptimizer(**kwargs)
         last_epoch = -1
         verbose = False
         super().__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
-        return [group["lr"] for group in self.optimizer.param_groups]
+        # default value
+        lrs = [self.default_lr]
+
+        # we take each lr in the parameters if they exist, assumes the optimizer to be the `LayerWiseDummyOptimizer`
+        if self.optimizer is not None:
+            param_wise_lrs = [
+                [group["lr"] for group in optim.param_groups] for optim in self.optimizer.optimizer_dict.values()
+            ]
+            lrs = list(chain(*param_wise_lrs))
+
+        return lrs

     def _get_closed_form_lr(self):
         return self.base_lrs
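
To make the change concrete, here is a small standalone sketch (toy parameters and lrs, not from the diff) of what the new get_lr() effectively computes: one entry per param group of every wrapped per-parameter optimizer, flattened with itertools.chain, instead of the dummy optimizer's single static value.

# Illustration only: the shape of the values the fixed get_lr() reports.
import torch
from itertools import chain

p1, p2 = torch.nn.Parameter(torch.zeros(2)), torch.nn.Parameter(torch.zeros(2))
# stand-in for LayerWiseDummyOptimizer.optimizer_dict: one real optimizer per parameter
optimizer_dict = {p1: torch.optim.AdamW([p1], lr=2e-4), p2: torch.optim.AdamW([p2], lr=2e-4)}

param_wise_lrs = [[group["lr"] for group in optim.param_groups] for optim in optimizer_dict.values()]
print(list(chain(*param_wise_lrs)))  # [0.0002, 0.0002] -- these values change as the per-parameter schedulers step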

tests/trainer/test_trainer.py

@@ -1653,6 +1653,84 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(galore_peak_memory < upper_bound_pm)
         self.assertTrue(lower_bound_pm < galore_peak_memory)

+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_lr_display_without_scheduler(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            learning_rate = 1e-9
+            num_steps = 10
+
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=learning_rate,
+                logging_steps=5,
+                optim="galore_adamw",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+            trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+            # reflects displayed lr in trainer
+            self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_lr_display_with_scheduler(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            learning_rate = 2e-4
+            num_train_epochs = 2
+            num_warmup_steps = 5
+
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                num_train_epochs=num_train_epochs,
+                learning_rate=learning_rate,
+                warmup_steps=num_warmup_steps,
+                lr_scheduler_type="cosine",
+                logging_steps=1,
+                optim="galore_adamw",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # creating log history of trainer, results don't matter
+            trainer.train()
+            logs = trainer.state.log_history[1:][:-1]
+
+            # reach given learning rate peak and end with 0 lr
+            self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
+            self.assertTrue(logs[-1]["learning_rate"] == 0)
+
+            # increasing and decreasing pattern of lrs
+            increasing_lrs = [
+                logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
+                for i in range(len(logs))
+                if i < num_warmup_steps - 2
+            ]
+            decreasing_lrs = [
+                logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
+                for i in range(len(logs) - 1)
+                if i >= num_warmup_steps - 2
+            ]
+            self.assertTrue(all(increasing_lrs))
+            self.assertTrue(all(decreasing_lrs))
+
+            # warm up steps << total steps
+            self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
+
     @require_torch_multi_accelerator
     def test_data_is_not_parallelized_when_model_is_parallel(self):
         model = RegressionModel()
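
The assertions in test_galore_lr_display_with_scheduler encode the expected warmup-then-cosine lr shape. As a standalone reference, independent of the test and of GaLore, the sketch below (made-up step counts, plain AdamW) traces the trajectory that transformers' cosine schedule with warmup produces:

# Sketch only: warmup-then-cosine lr trajectory on a plain optimizer.
import torch
from transformers import get_cosine_schedule_with_warmup

opt = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=2e-4)
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=5, num_training_steps=20)

lrs = []
for _ in range(20):
    opt.step()   # no gradients needed for this illustration
    sched.step()
    lrs.append(sched.get_last_lr()[0])

# the lr rises linearly to the peak of 2e-4 over the warmup steps, then decays to 0 by the last step
print(max(lrs), lrs[-1])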