mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
c7ec71baf5
commit
ee8e80a060
@ -1481,6 +1481,11 @@ class Trainer:
|
||||
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
|
||||
if type(model) != FSDP:
|
||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||
signature = inspect.signature(FSDP.__init__).parameters.keys()
|
||||
kwargs = {}
|
||||
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
|
||||
if arg in signature:
|
||||
kwargs[arg] = getattr(self, arg)
|
||||
self.model = model = FSDP(
|
||||
model,
|
||||
sharding_strategy=self.fsdp,
|
||||
@ -1488,9 +1493,7 @@ class Trainer:
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
mixed_precision=mixed_precision_policy,
|
||||
device_id=self.args.device,
|
||||
backward_prefetch=self.backward_prefetch,
|
||||
forward_prefetch=self.forword_prefetch,
|
||||
limit_all_gathers=self.limit_all_gathers,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user