Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Actual fix (#9787)

commit 0f443436fb
parent fac7cfb16a
src/transformers/models/gpt2/modeling_gpt2.py

@@ -541,6 +541,7 @@ class GPT2Model(GPT2PreTrainedModel):
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
@@ -805,7 +806,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
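For context (an addition, not part of the commit): the point of initializing these flags in __init__ is that the forward pass consults self.model_parallel on every instance, including models that never call parallelize(). A minimal usage sketch, assuming a 2-GPU machine and the transformers GPT-2 model-parallel API of this era; the layer split below is illustrative:

    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # The attributes this commit initializes in __init__, present even
    # if parallelize() is never called:
    assert model.model_parallel is False
    assert model.device_map is None

    # Spread gpt2's 12 transformer blocks across two GPUs (illustrative split).
    device_map = {0: list(range(0, 6)), 1: list(range(6, 12))}
    model.parallelize(device_map)  # sets model.model_parallel = True

    model.deparallelize()  # back to one device; flags reset to the defaults above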
@@ -971,6 +974,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def get_output_embeddings(self):
        return self.lm_head
@@ -1153,6 +1160,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
src/transformers/models/t5/modeling_t5.py

@@ -1651,6 +1651,10 @@ class T5EncoderModel(T5PreTrainedModel):
        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
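The T5 hunk guards the same code path; as the truncated line above suggests, parallelize() computes a device_map itself when none is given. A hedged sketch (not from the commit), assuming multiple CUDA devices and the t5-small checkpoint:

    import torch
    from transformers import T5EncoderModel

    model = T5EncoderModel.from_pretrained("t5-small")
    assert model.model_parallel is False and model.device_map is None

    if torch.cuda.device_count() > 1:
        model.parallelize()  # device_map=None: balanced split over all GPUs
        print(model.device_map)
        model.deparallelize()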