fix FSDP version related issues (#22489)

fix fsdp
This commit is contained in:
Sourab Mangrulkar 2023-04-07 04:25:19 +05:30 committed by GitHub
parent c7ec71baf5
commit ee8e80a060
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1481,6 +1481,11 @@ class Trainer:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
signature = inspect.signature(FSDP.__init__).parameters.keys()
kwargs = {}
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
if arg in signature:
kwargs[arg] = getattr(self, arg)
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
@ -1488,9 +1493,7 @@ class Trainer:
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
backward_prefetch=self.backward_prefetch,
forward_prefetch=self.forword_prefetch,
limit_all_gathers=self.limit_all_gathers,
**kwargs,
)
else:
try: