Fix: Suppressed 'use_reentrant=False' warning (#33208)

Co-authored-by: Ankush <ankush13r>
This commit is contained in:
Ankush 2024-09-02 10:16:07 +02:00 committed by GitHub
parent 1ca9ff5c91
commit 409fcfdfcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2118,12 +2118,7 @@ class Trainer:
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
if args.gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
else:
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
model = self._wrap_model(self.model_wrapped)