fix: prevent model access error during Optuna hyperparameter tuning (#36395)

* fix: prevent model access error during Optuna hyperparameter tuning

The `transformers.integrations.integration_utils.run_hp_search_optuna` function releases model memory and sets `trainer.model` to `None` after each trial. Subsequent `Trainer.train` calls then raise an `AttributeError` when they try to access the model before it has been reinitialized. This only occurs when the `fp16_full_eval` or `bf16_full_eval` flag is enabled, since `train()` then moves the model to the device up front; the fix skips that early device move whenever a `model_init` is provided, because `train()` rebuilds the model and places it on the device itself.
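A minimal sketch of the failing pattern, for illustration only (`make_model_init`, `train_dataset`, and `eval_dataset` are hypothetical placeholders, not part of this commit):

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        fp16_full_eval=True,  # enables the early _move_model_to_device call in train()
        disable_tqdm=True,
    )
    trainer = Trainer(
        model_init=make_model_init(),  # hypothetical factory returning a fresh model per trial
        args=args,
        train_dataset=train_dataset,  # hypothetical datasets
        eval_dataset=eval_dataset,
    )

    # Trial 1 succeeds. run_hp_search_optuna then frees the model and sets
    # trainer.model = None, so before this fix trial 2's train() dereferenced
    # self.model during the early device move, e.g.:
    #   AttributeError: 'NoneType' object has no attribute 'to'
    trainer.hyperparameter_search(direction="minimize", n_trials=2)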

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Author: Manny Cortes · 2025-02-26 08:06:48 -08:00 · committed by GitHub
parent 6513e5e402
commit 082834dd79
2 changed files with 38 additions and 1 deletion

src/transformers/trainer.py

@@ -2180,7 +2180,12 @@ class Trainer:
         # do_train is not a reliable argument, as it might not be set and .train() still called, so
         # the following is a workaround:
-        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel:
+        if (
+            (args.fp16_full_eval or args.bf16_full_eval)
+            and not args.do_train
+            and not self.is_model_parallel
+            and self.model_init is None
+        ):
             self._move_model_to_device(self.model, args.device)
 
         if "model_path" in kwargs:

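The added `self.model_init is None` condition is safe because when a `model_init` is supplied (the hyperparameter-search path), `train()` rebuilds the model itself. A simplified paraphrase of that flow (not the verbatim library source):

    # Simplified paraphrase of Trainer.train(); not verbatim library code.
    def train(self, resume_from_checkpoint=None, trial=None, **kwargs):
        # self.model may be None here: run_hp_search_optuna freed it after the
        # previous trial. The new guard avoids _move_model_to_device(None, ...).
        if self.model_init is not None:
            self.model = self.call_model_init(trial)  # fresh model for this trial
            # ... the reinitialized model is moved to args.device during setup ...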
tests/trainer/test_trainer.py

@@ -4998,6 +4998,38 @@ class TrainerHyperParameterMultiObjectOptunaIntegrationTest(unittest.TestCase):
         )
 
 
+@require_torch
+@require_optuna
+class TrainerHyperParameterOptunaIntegrationTestWithFullEval(unittest.TestCase):
+    def test_hyperparameter_search(self):
+        def hp_space(trial):
+            return {}
+
+        def model_init(trial):
+            if trial is not None:
+                a = trial.suggest_int("a", -4, 4)
+                b = trial.suggest_int("b", -4, 4)
+            else:
+                a = 0
+                b = 0
+            config = RegressionModelConfig(a=a, b=b, double_output=False)
+            return RegressionPreTrainedModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                disable_tqdm=True,
+                model_init=model_init,
+                fp16_full_eval=True,
+            )
+            trainer.hyperparameter_search(
+                direction="minimize",
+                hp_space=hp_space,
+                n_trials=2,
+            )
+
+
 @require_torch
 @require_ray
 class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
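
To exercise the new test locally, an invocation along these lines should work (assuming `optuna` is installed and the device supports fp16; the `-k` filter simply matches the new class name):

    python -m pytest tests/trainer/test_trainer.py -k TrainerHyperParameterOptunaIntegrationTestWithFullEval -v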