Fix check for backword_pos (#23075)

This commit is contained in:
Wing Lian 2023-05-02 09:32:42 -04:00 committed by GitHub
parent f31a510bb3
commit c6c6658499
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -458,7 +458,9 @@ class Trainer:
self.fsdp = ShardingStrategy.NO_SHARD
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch:
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
"backward_prefetch", []
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
self.forward_prefetch = False