shift torch dynamo handling to accelerate (#23168)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* refactor the place to create the `Accelerator` object

* move ddp prep to accelerate

* fix 😅

* resolving comments

* move fsdp handling to accelerate

* fixes

* fix saving

* shift torch dynamo handling to accelerate
commit 03db591047 (parent 0b774074a5)
Author: Sourab Mangrulkar, 2023-05-31 14:42:07 +05:30, committed via GitHub
2 changed files with 9 additions and 5 deletions

src/transformers/trainer.py

@@ -1559,11 +1559,6 @@ class Trainer:
             self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
-        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
-        # to ensure that it accounts for the graph breaks required by those wrappers
-        if self.args.torch_compile:
-            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
-
         return model

     def train(

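With the `torch.compile` call removed from `Trainer`, compilation is now applied on the accelerate side when the model is prepared. A minimal sketch of that mechanism, assuming accelerate's default dynamo plugin reads the `ACCELERATE_DYNAMO_*` environment variables and that `Accelerator.prepare` compiles the model after DDP/FSDP wrapping; the toy model below is purely illustrative:

import os

import torch
from accelerate import Accelerator

# These are the env vars TrainingArguments.__post_init__ now exports when
# torch_compile=True; accelerate's dynamo plugin picks them up on creation.
os.environ["ACCELERATE_DYNAMO_BACKEND"] = "inductor"
os.environ["ACCELERATE_DYNAMO_MODE"] = "default"

accelerator = Accelerator()

# Illustrative model: any nn.Module is handled the same way.
model = torch.nn.Linear(8, 2)

# prepare() wraps the model for the current distributed setup and, because a
# dynamo backend is configured, compiles it afterwards -- the same ordering
# the removed Trainer code enforced by hand.
model = accelerator.prepare(model)
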
src/transformers/training_args.py

@@ -1371,6 +1371,15 @@ class TrainingArguments:
             self.torch_compile = True
         if self.torch_compile and self.torch_compile_backend is None:
             self.torch_compile_backend = "inductor"
+
+        # accelerate integration for torch compile
+        if self.torch_compile:
+            # set env vars for accelerate
+            prefix = "ACCELERATE_DYNAMO_"
+            os.environ[prefix + "BACKEND"] = self.torch_compile_backend
+            if self.torch_compile_mode is not None:
+                os.environ[prefix + "MODE"] = self.torch_compile_mode
+
         if self.framework == "pt" and is_torch_available() and self.torch_compile:
             if is_torch_tf32_available():
                 if self.tf32 is None and not self.fp16 or self.bf16:
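On the user side nothing changes: the same `torch_compile*` flags on `TrainingArguments` now simply export the environment variables above before the `Trainer` builds its `Accelerator`. A hedged usage sketch, with a placeholder output directory:

import os

from transformers import TrainingArguments

# torch_compile=True defaults the backend to "inductor" (see the hunk above)
# and exports ACCELERATE_DYNAMO_BACKEND / ACCELERATE_DYNAMO_MODE so that the
# Accelerator created later by Trainer applies torch.compile itself.
args = TrainingArguments(
    output_dir="tmp_trainer",           # placeholder path
    torch_compile=True,
    torch_compile_mode="max-autotune",  # optional; sets ACCELERATE_DYNAMO_MODE
)

print(os.environ["ACCELERATE_DYNAMO_BACKEND"])  # "inductor"
print(os.environ["ACCELERATE_DYNAMO_MODE"])     # "max-autotune"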