diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 555a8eb0f58..bd8435365c8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1481,6 +1481,11 @@ class Trainer:
                 mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
             if type(model) != FSDP:
                 # XXX: Breaking the self.model convention but I see no way around it for now.
+                signature = inspect.signature(FSDP.__init__).parameters.keys()
+                kwargs = {}
+                for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
+                    if arg in signature:
+                        kwargs[arg] = getattr(self, arg)
                 self.model = model = FSDP(
                     model,
                     sharding_strategy=self.fsdp,
@@ -1488,9 +1493,7 @@ class Trainer:
                     auto_wrap_policy=auto_wrap_policy,
                     mixed_precision=mixed_precision_policy,
                     device_id=self.args.device,
-                    backward_prefetch=self.backward_prefetch,
-                    forward_prefetch=self.forword_prefetch,
-                    limit_all_gathers=self.limit_all_gathers,
+                    **kwargs,
                 )
             else:
                 try:
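
Note for reviewers: the added lines use a standard feature-detection pattern. They probe the installed FSDP constructor's signature once with inspect.signature and forward only the keyword arguments that this PyTorch version actually accepts, so older releases that predate limit_all_gathers / forward_prefetch / backward_prefetch no longer raise a TypeError. A minimal, self-contained sketch of the same idea (the filter_supported_kwargs helper and the old_fsdp_init stub are illustrative only, not part of this patch):

    import inspect

    def filter_supported_kwargs(fn, candidate_kwargs):
        # Keep only the kwargs that the installed version of `fn` accepts,
        # mirroring the signature probe added to Trainer above.
        accepted = inspect.signature(fn).parameters.keys()
        return {name: value for name, value in candidate_kwargs.items() if name in accepted}

    # Stand-in for an older FSDP.__init__ that predates the prefetch arguments.
    def old_fsdp_init(module, sharding_strategy=None, limit_all_gathers=False):
        return {"module": module, "limit_all_gathers": limit_all_gathers}

    kwargs = filter_supported_kwargs(
        old_fsdp_init,
        {"limit_all_gathers": True, "forward_prefetch": True, "backward_prefetch": None},
    )
    assert kwargs == {"limit_all_gathers": True}  # unsupported args are silently dropped

One trade-off worth flagging: on an old PyTorch this check degrades to silently dropping the three arguments rather than failing loudly, which is the intended back-compat behavior here but would hide genuine misspellings of argument names.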