Add some warning for Dynamo and enable TF32 when it's set (#20515)

commit e342ac7e03 (parent 68cfffc4b4)
Author: Sylvain Gugger (committed via GitHub)
Date: 2022-11-30 15:42:17 -05:00


@@ -1148,6 +1148,15 @@ class TrainingArguments:
" (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
)
if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
if is_torch_tf32_available():
if self.tf32 is None and not self.fp16 or self.bf16:
logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
torch.backends.cuda.matmul.allow_tf32 = True
else:
logger.warning(
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
)
if self.framework == "pt" and is_torch_available() and self.tf32 is not None:
if self.tf32:
if is_torch_tf32_available():
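
For illustration, here is a minimal standalone sketch of the availability check this hunk relies on. The helper name tf32_capable is hypothetical; in transformers the actual check is `is_torch_tf32_available`, which requires a CUDA 11+ build of PyTorch and an Ampere-class (compute capability 8.x) or newer GPU:

import torch

def tf32_capable() -> bool:
    # Hypothetical stand-in for transformers' is_torch_tf32_available:
    # TF32 tensor cores need a CUDA 11+ PyTorch build and an Ampere
    # (compute capability 8.x) or newer GPU.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

if tf32_capable():
    # The same switch the hunk flips: allow TF32 tensor cores for CUDA matmuls.
    torch.backends.cuda.matmul.allow_tf32 = True

TF32 trades a small amount of matmul precision for throughput, which is why the hunk only enables it automatically when the user has left `tf32` unset.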