Mirror of https://github.com/huggingface/transformers.git
FEAT [Trainer / bnb]: Add RMSProp from bitsandbytes to HF Trainer (#29082)
* add RMSProp to Trainer

* revert some change

* Update src/transformers/trainer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent a7ff2f23a0
commit f7ef7cec6c
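With this change, the bitsandbytes RMSprop variants can be selected directly through `TrainingArguments(optim=...)`, as the diff below wires up. A minimal usage sketch, assuming bitsandbytes is installed and a transformers build that includes this commit; `TinyLMDataset` is a hypothetical stand-in for the tests' `RepeatDataset` helper, and the output directory and step counts are arbitrary:

```python
import torch
from torch.utils.data import Dataset

from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments


class TinyLMDataset(Dataset):
    # Hypothetical stand-in for the RepeatDataset helper used in the tests below.
    def __init__(self, length=64):
        self.x = torch.randint(0, 100, (128,))
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.x, "labels": self.x}


# Tiny illustrative model, mirroring the configuration used in the new tests.
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
model = GPT2LMHeadModel(config)

args = TrainingArguments(
    output_dir="rmsprop_bnb_demo",
    optim="rmsprop_bnb_8bit",  # also accepts "rmsprop_bnb" and "rmsprop_bnb_32bit"
    max_steps=10,
    logging_steps=5,
)
Trainer(model, args, train_dataset=TinyLMDataset()).train()
```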
@@ -1084,9 +1084,12 @@ class Trainer:
             OptimizerNames.LION_8BIT,
             OptimizerNames.PAGED_LION,
             OptimizerNames.PAGED_LION_8BIT,
+            OptimizerNames.RMSPROP_BNB,
+            OptimizerNames.RMSPROP_8BIT,
+            OptimizerNames.RMSPROP_32BIT,
         ]:
             try:
-                from bitsandbytes.optim import AdamW, Lion
+                from bitsandbytes.optim import AdamW, Lion, RMSprop

                 is_paged = False
                 optim_bits = 32
@@ -1101,8 +1104,16 @@ class Trainer:
                 elif "lion" in args.optim:
                     optimizer_cls = Lion
                     additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
+                elif "rmsprop" in args.optim:
+                    optimizer_cls = RMSprop
+                    # Above we pass all `adam_kwargs` to the optimizer, here
+                    # we only pass `optim_args` which can be passed by the user.
+                    additional_optim_kwargs = optim_args
+
+                bnb_kwargs = {"optim_bits": optim_bits}
+                if "rmsprop" not in args.optim:
+                    bnb_kwargs["is_paged"] = is_paged

-                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                 optimizer_kwargs.update(additional_optim_kwargs)
                 optimizer_kwargs.update(bnb_kwargs)
             except ImportError:
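For readers tracing the branch above, a rough sketch of what the rmsprop path boils down to: `bnb_kwargs` carries only `optim_bits` for the rmsprop optimizers, while `is_paged` is forwarded only for the other bitsandbytes families. The `torch.nn.Linear` model below is a hypothetical stand-in, and `optim_bits=8` corresponds to the `*_8bit` variant (32 otherwise):

```python
import torch
from bitsandbytes.optim import RMSprop

# Hypothetical tiny model standing in for the user's model.
model = torch.nn.Linear(16, 16)

# Mirrors the new branch: only `optim_bits` is passed through `bnb_kwargs`
# for rmsprop; `is_paged` is skipped for this optimizer family.
optimizer = RMSprop(model.parameters(), lr=1e-3, optim_bits=8)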
@@ -157,6 +157,9 @@ class OptimizerNames(ExplicitEnum):
     PAGED_LION = "paged_lion_32bit"
     PAGED_LION_8BIT = "paged_lion_8bit"
     RMSPROP = "rmsprop"
+    RMSPROP_BNB = "rmsprop_bnb"
+    RMSPROP_8BIT = "rmsprop_bnb_8bit"
+    RMSPROP_32BIT = "rmsprop_bnb_32bit"


 # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
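A quick sanity check on the new enum members (a minimal sketch, assuming `OptimizerNames` remains importable from `transformers.training_args` as in the current module layout):

```python
from transformers.training_args import OptimizerNames

# The enum values are exactly the strings users pass via TrainingArguments(optim=...).
assert OptimizerNames("rmsprop_bnb") is OptimizerNames.RMSPROP_BNB
assert OptimizerNames("rmsprop_bnb_8bit") is OptimizerNames.RMSPROP_8BIT
assert OptimizerNames("rmsprop_bnb_32bit") is OptimizerNames.RMSPROP_32BIT
```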
@@ -58,6 +58,7 @@ from transformers.testing_utils import (
     get_tests_dir,
     is_staging_test,
     require_accelerate,
+    require_bitsandbytes,
     require_deepspeed,
     require_intel_extension_for_pytorch,
     require_optuna,
@@ -872,6 +873,56 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, 10)

+    @require_bitsandbytes
+    def test_rmsprop_bnb(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
+    @require_bitsandbytes
+    def test_rmsprop_bnb_8bit(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_8bit"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
+    @require_bitsandbytes
+    def test_rmsprop_bnb_32bit(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_32bit"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
     def test_neftune(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
         tiny_gpt2 = GPT2LMHeadModel(config)