Mirror of https://github.com/huggingface/transformers.git
[fix] make legacy bnb code work (#37331)
* [fix] make legacy bnb code work
* [fix] use get with default instead of getter
* add test for bnb 8bit optim skip embed
* [fix] style
* add require annotation of bnb

Co-authored-by: jaycha <jaycha@ncsoft.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 5f791281c3
commit 31ea547b7a
@@ -1247,7 +1247,7 @@ class Trainer:
         self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

-        if optimizer_cls.__name__ == "Adam8bit":
+        if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
             import bitsandbytes

             manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

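The old guard only matched the class name "Adam8bit", so 8-bit setups built from other bitsandbytes optimizer classes, or from a generic class configured with optim_bits=8 in the kwargs, never reached the embedding override below it. A minimal sketch of the difference, assuming the legacy path constructs bitsandbytes.optim.AdamW with optim_bits=8 (the values here are illustrative, not the Trainer's actual code):

# Sketch: why the name-based guard misses the legacy bitsandbytes path.
import bitsandbytes as bnb

optimizer_cls = bnb.optim.AdamW                      # legacy path uses the generic class
optimizer_kwargs = {"lr": 5e-5, "optim_bits": 8}     # 8-bit requested via kwargs

# Old check: only matches the dedicated 8-bit Adam class by name.
print(optimizer_cls.__name__ == "Adam8bit")          # False -> embedding override skipped

# New check: any bitsandbytes optimizer configured for 8-bit states qualifies.
print(
    "bitsandbytes" in str(optimizer_cls)
    and optimizer_kwargs.get("optim_bits", None) == 8
)                                                    # True -> override is applied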
@@ -5962,3 +5962,22 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
         param = next(model.parameters())
         group = trainer.get_optimizer_group(param)
         self.assertIn(param, group["params"])
+
+    @require_bitsandbytes
+    def test_bnb_8bit_optimizer_skip_embedding(self):
+        model = BasicTextGenerationModel(8, 4)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            for name_optim in ["rmsprop_bnb_8bit", "adamw_8bit"]:
+                args = TrainingArguments(
+                    output_dir=tmp_dir,
+                    report_to="none",
+                    optim=name_optim,
+                )
+                trainer = Trainer(model=model, args=args)
+                optimizer = trainer.create_optimizer()
+                modules = optimizer.mng.module_weight_config_triple
+                self.assertNotEqual(len(modules), 0)
+                module, name, config = modules[0]
+                self.assertIsInstance(module, torch.nn.Embedding)
+                self.assertEqual(name, "weight")
+                self.assertDictEqual(config, {"optim_bits": 32})
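For context, the new test asserts that when an 8-bit bitsandbytes optimizer is selected, the Trainer registers a 32-bit override for embedding weights with bitsandbytes' GlobalOptimManager, which is then visible through the optimizer's mng.module_weight_config_triple. A minimal standalone sketch of that mechanism (the toy model and values here are illustrative, not the Trainer's actual code):

# Sketch: keep embedding weights in 32-bit optimizer states while the rest
# of the model uses 8-bit states, via bitsandbytes' GlobalOptimManager.
import torch
import bitsandbytes

model = torch.nn.Sequential(torch.nn.Embedding(8, 4), torch.nn.Linear(4, 4))

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, torch.nn.Embedding):
        manager.register_module_override(module, "weight", {"optim_bits": 32})

optimizer = bitsandbytes.optim.AdamW(model.parameters(), lr=5e-5, optim_bits=8)
# The test above checks that an override of this shape, (Embedding, "weight",
# {"optim_bits": 32}), is recorded in optimizer.mng.module_weight_config_triple.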