mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix FSDP + torch.compile bug when saving pretrained model (#37725)
* args keep_torch_compile=False in _save and _wwrap_method * Fix FSDP execution on evaluation for torch_compile mode * add test trainer FSDP + Torch Compile * fix quality code * make style * Revert " make style" This reverts commit 77e797f8829c50992cc21496be3d9a3e480e1c97. * make style
This commit is contained in:
parent
5534b80b7f
commit
031ef8802c
@ -1986,7 +1986,7 @@ class Trainer:
|
||||
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
|
||||
|
||||
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
|
||||
if self.accelerator.unwrap_model(model) is not model:
|
||||
if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model:
|
||||
return model
|
||||
|
||||
# Mixed precision training with apex
|
||||
@ -3998,8 +3998,8 @@ class Trainer:
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
|
||||
self.accelerator.unwrap_model(self.model).save_pretrained(
|
||||
if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes):
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
else:
|
||||
@ -4296,7 +4296,8 @@ class Trainer:
|
||||
start_time = time.time()
|
||||
model = (
|
||||
self.accelerator.prepare(model)
|
||||
if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
|
||||
if self.is_deepspeed_enabled
|
||||
or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8" and not self.args.torch_compile)
|
||||
else self.accelerator.prepare_model(model, evaluation_mode=True)
|
||||
)
|
||||
self.model_preparation_time = round(time.time() - start_time, 4)
|
||||
|
@ -147,6 +147,34 @@ class TestFSDPTrainerWrap(TestCasePlus):
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
|
||||
class TestFSDPTrainerTorchCompile(TestCasePlus):
|
||||
@require_torch_multi_accelerator
|
||||
@require_accelerate
|
||||
@run_first
|
||||
def test_trainer(self):
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--use_fsdp",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"--num_processes",
|
||||
f"{backend_device_count(torch_device)}",
|
||||
"--fsdp_transformer_layer_cls_to_wrap",
|
||||
"GPT2Block",
|
||||
f"{self.test_file_dir}/test_trainer_fsdp.py",
|
||||
"--torch_compile_mode",
|
||||
"default",
|
||||
"--output_dir",
|
||||
f"{output_dir}",
|
||||
"--report_to",
|
||||
"none",
|
||||
]
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((Seq2SeqTrainingArguments,))
|
||||
training_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
Loading…
Reference in New Issue
Block a user