Mirror of https://github.com/huggingface/transformers.git
Fix eval thread fork bomb (#29538)
* Fix eval thread fork bomb
* Keep eval dl persistent and prepare after so free_memory doesn't destroy it
* Add note
* Quality
parent 3f6973db06
commit 469c13280d
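For context: the fork bomb only shows up when DataLoader workers are enabled and kept persistent while evaluation runs repeatedly during training. The sketch below is illustrative, not taken from the commit; the argument values are assumptions.

# Hedged sketch of a configuration where the old behaviour multiplied worker
# processes: every evaluate() call used to build a fresh eval DataLoader,
# each with its own set of persistent workers. Values here are illustrative.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",          # evaluate repeatedly during training
    eval_steps=500,
    dataloader_num_workers=4,             # worker processes per DataLoader
    dataloader_persistent_workers=True,   # keep workers alive between epochs
)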
@@ -888,6 +888,11 @@ class Trainer:
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+        # If we have persistent workers, don't do a fork bomb especially as eval datasets
+        # don't change during training
+        if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
+            return self.accelerator.prepare(self._eval_dataloader)
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
         data_collator = self.data_collator
 
@@ -909,7 +914,13 @@ class Trainer:
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
-        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
+        # accelerator.free_memory() will destroy the references, so
+        # we need to store the non-prepared version
+        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
+        if self.args.dataloader_persistent_workers:
+            self._eval_dataloader = eval_dataloader
+
+        return self.accelerator.prepare(eval_dataloader)
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
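For readers outside the Trainer codebase, here is a minimal standalone sketch of the caching pattern the fix adopts, written against plain PyTorch rather than the Trainer API; the EvalRunner class and its names are assumptions for illustration only.

# Minimal sketch (assumed example, not Trainer code): build the persistent-worker
# eval DataLoader once, cache it, and reuse it for every evaluation, so repeated
# evaluations do not spawn a fresh set of worker processes each time.
import torch
from torch.utils.data import DataLoader, TensorDataset


class EvalRunner:
    def __init__(self, dataset, num_workers=2):
        self.dataset = dataset
        self.num_workers = num_workers
        self._eval_dataloader = None  # cached loader, mirrors Trainer._eval_dataloader

    def get_eval_dataloader(self):
        # Reuse the cached loader so its workers survive across evaluations.
        if self._eval_dataloader is None:
            self._eval_dataloader = DataLoader(
                self.dataset,
                batch_size=8,
                num_workers=self.num_workers,
                persistent_workers=self.num_workers > 0,
            )
        return self._eval_dataloader

    def evaluate(self):
        loader = self.get_eval_dataloader()
        return sum(batch[0].sum().item() for batch in loader)


if __name__ == "__main__":
    runner = EvalRunner(TensorDataset(torch.arange(64).float().unsqueeze(1)))
    for _ in range(3):  # repeated evaluations reuse the same worker processes
        runner.evaluate()

In the Trainer itself the cached dataloader is stored un-prepared and passed through accelerator.prepare() on each call, because accelerator.free_memory() drops references to prepared objects between runs; that is what the second half of the diff above implements.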