Fix test for torchneuroncore in Trainer (#22028)

This commit is contained in:
Sylvain Gugger 2023-03-08 09:12:43 -05:00 committed by GitHub
parent de81adf978
commit a5392ee747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1538,7 +1538,7 @@ class Trainer:
if self.args.ddp_bucket_cap_mb is not None: if self.args.ddp_bucket_cap_mb is not None:
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
if is_torch_neuroncore_available: if is_torch_neuroncore_available():
return model return model
model = nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, model,