[fix] make legacy bnb code work (#37331)

* [fix] make legacy bnb code work

* [fix] use .get() with a default instead of direct key lookup (see the note after this list)

* add a test that bnb 8-bit optimizers skip embedding weights

* [fix] style

* add @require_bitsandbytes annotation to the new test
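
Note: "get with default" here is dict.get(key, default), which returns the default instead of raising KeyError when the key is absent — this matters because "optim_bits" is only present in optimizer_kwargs for some optimizers. A minimal sketch with a hypothetical kwargs dict:

    optimizer_kwargs = {"lr": 2e-5}                 # hypothetical kwargs; "optim_bits" never set
    # optimizer_kwargs["optim_bits"] == 8           # direct lookup would raise KeyError here
    optimizer_kwargs.get("optim_bits", None) == 8   # get with default just evaluates to False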

---------

Co-authored-by: jaycha <jaycha@ncsoft.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
youngrok cha 2025-04-22 18:17:29 +09:00 committed by GitHub
parent 5f791281c3
commit 31ea547b7a
2 changed files with 20 additions and 1 deletion

src/transformers/trainer.py

@@ -1247,7 +1247,7 @@ class Trainer:
             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-            if optimizer_cls.__name__ == "Adam8bit":
+            if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
                 import bitsandbytes

                 manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
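
For context, GlobalOptimManager is what lets bitsandbytes 8-bit optimizers keep selected parameters in 32-bit. A minimal sketch of how the code around this hunk uses the manager (paraphrased, not the verbatim source; opt_model stands for the model being optimized): every nn.Embedding weight gets an fp32 override so embeddings are excluded from the 8-bit path:

    import bitsandbytes
    import torch.nn as nn

    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
    for module in opt_model.modules():
        if isinstance(module, nn.Embedding):
            # register an override so this weight is optimized in 32-bit,
            # even though the optimizer itself runs in 8-bit
            manager.register_module_override(module, "weight", {"optim_bits": 32})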

tests/trainer/test_trainer.py

@@ -5962,3 +5962,22 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
         param = next(model.parameters())
         group = trainer.get_optimizer_group(param)
         self.assertIn(param, group["params"])
+
+    @require_bitsandbytes
+    def test_bnb_8bit_optimizer_skip_embedding(self):
+        model = BasicTextGenerationModel(8, 4)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            for name_optim in ["rmsprop_bnb_8bit", "adamw_8bit"]:
+                args = TrainingArguments(
+                    output_dir=tmp_dir,
+                    report_to="none",
+                    optim=name_optim,
+                )
+                trainer = Trainer(model=model, args=args)
+                optimizer = trainer.create_optimizer()
+                modules = optimizer.mng.module_weight_config_triple
+                self.assertNotEqual(len(modules), 0)
+                module, name, config = modules[0]
+                self.assertIsInstance(module, torch.nn.Embedding)
+                self.assertEqual(name, "weight")
+                self.assertDictEqual(config, {"optim_bits": 32})
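
The test reaches the manager through optimizer.mng, the handle bitsandbytes 8-bit optimizers keep on the GlobalOptimManager singleton; module_weight_config_triple is its list of (module, param_name, config) overrides. A standalone sketch of what the assertions check (assumes torch and bitsandbytes are installed; the singleton may already hold earlier registrations, hence [-1]):

    import bitsandbytes
    import torch

    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
    emb = torch.nn.Embedding(8, 4)  # hypothetical toy embedding
    manager.register_module_override(emb, "weight", {"optim_bits": 32})

    module, name, config = manager.module_weight_config_triple[-1]
    assert isinstance(module, torch.nn.Embedding)
    assert name == "weight" and config == {"optim_bits": 32}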