Actual fix (#9787)

This commit is contained in:
Lysandre Debut 2021-01-25 17:12:07 +01:00 committed by GitHub
parent fac7cfb16a
commit 0f443436fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 0 deletions

View File

@ -541,6 +541,7 @@ class GPT2Model(GPT2PreTrainedModel):
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@ -805,7 +806,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
@ -971,6 +974,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
def get_output_embeddings(self):
return self.lm_head
@ -1153,6 +1160,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@ -1651,6 +1651,10 @@ class T5EncoderModel(T5PreTrainedModel):
self.init_weights()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (