Fix eval thread fork bomb (#29538)

* Fix eval thread fork bomb

* Keep eval dl persistent and prepare after so free_memory doesn't destroy it

* Add note

* Quality
This commit is contained in:
Zach Mueller 2024-03-08 11:04:18 -05:00 committed by GitHub
parent 3f6973db06
commit 469c13280d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -888,6 +888,11 @@ class Trainer:
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
# If we have persistent workers, don't do a fork bomb especially as eval datasets
# don't change during training
if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
return self.accelerator.prepare(self._eval_dataloader)
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
data_collator = self.data_collator
@ -909,7 +914,13 @@ class Trainer:
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
# accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
if self.args.dataloader_persistent_workers:
self._eval_dataloader = eval_dataloader
return self.accelerator.prepare(eval_dataloader)
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
"""