fix: Avoid error when fsdp_config is missing xla_fsdp_v2 (#29480)

Signed-off-by: Ashok Pon Kumar Sree Prakash <ashokponkumar@gmail.com>
Ashok Pon Kumar 2024-03-07 17:14:23 +05:30 committed by GitHub
parent f6133d767a
commit 9288e759ad


@@ -647,7 +647,7 @@ class Trainer:
         if args.torch_compile and not is_torch_compile_available():
             raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
-        self.is_fsdp_xla_v2_enabled = args.fsdp_config["xla_fsdp_v2"]
+        self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
         if self.is_fsdp_xla_v2_enabled:
             # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
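
For context, a minimal sketch of the behavior this commit fixes. The `fsdp_config` dict below is a hypothetical stand-in, not the Trainer's actual arguments: indexing a dict with `[]` raises KeyError when the key is absent, while `.get()` with a default of False does not.

    # Hypothetical fsdp_config that does not set "xla_fsdp_v2".
    fsdp_config = {"min_num_params": 0}

    # Old code path: direct indexing raises KeyError if the key is missing.
    try:
        is_fsdp_xla_v2_enabled = fsdp_config["xla_fsdp_v2"]
    except KeyError:
        print("KeyError: 'xla_fsdp_v2' -- the error this commit avoids")

    # New code path: .get() falls back to False when the key is absent.
    is_fsdp_xla_v2_enabled = fsdp_config.get("xla_fsdp_v2", False)
    print(is_fsdp_xla_v2_enabled)  # False

With the default of False, users who never configure XLA FSDP v2 simply skip the SPMD mesh setup instead of hitting a KeyError at Trainer construction time.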