diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b974ad4b200..3c0212e1c8c 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1408,8 +1408,7 @@ class T5Model(T5PreTrainedModel): ) hidden_states = encoder_outputs[0] - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism if self.model_parallel: torch.cuda.set_device(self.decoder.first_device)