GPT2DoubleHeadsModel made parallelizable (#10658)
* GPT2DoubleHeadsModel made parallelizable
* GPT2DoubleHeadsModel added to the parallelizable models in the GPT2 test suite
commit 505494a86f (parent e12d6f513e)
@@ -983,6 +983,28 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.model_parallel = False
         self.device_map = None
 
+    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.transformer.h))
+        self.transformer.parallelize(self.device_map)
+        self.lm_head = self.lm_head.to(self.transformer.first_device)
+        self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
+        self.model_parallel = True
+
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+    def deparallelize(self):
+        self.transformer.deparallelize()
+        self.transformer = self.transformer.to("cpu")
+        self.lm_head = self.lm_head.to("cpu")
+        self.multiple_choice_head = self.multiple_choice_head.to("cpu")
+        self.model_parallel = False
+        torch.cuda.empty_cache()
+
     def get_output_embeddings(self):
         return self.lm_head
 
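For orientation, here is a minimal usage sketch of the two new methods on a hypothetical two-GPU machine. The explicit device_map, the "gpt2" checkpoint, and the toy inputs are illustrative assumptions, not part of the change; calling parallelize() with no argument lets get_device_map spread the blocks evenly instead.

import torch
from transformers import GPT2DoubleHeadsModel

model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

# Hypothetical split of the 12 GPT-2 blocks across two GPUs; adjust to the hardware at hand.
device_map = {0: list(range(0, 6)), 1: list(range(6, 12))}
model.parallelize(device_map)

# GPT2DoubleHeadsModel takes (batch, num_choices, seq_len) inputs for the multiple-choice head.
input_ids = torch.tensor([[[50256, 50256, 50256]]]).to(model.transformer.first_device)
mc_token_ids = torch.tensor([[2]]).to(model.transformer.first_device)  # position of the classification token

outputs = model(input_ids, mc_token_ids=mc_token_ids)
print(outputs.logits.shape, outputs.mc_logits.shape)

model.deparallelize()  # everything back on CPU; deparallelize also empties the CUDA cache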
@@ -1096,6 +1118,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
 
         hidden_states = transformer_outputs[0]
 
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.transformer.first_device)
+            hidden_states = hidden_states.to(self.lm_head.weight.device)
+
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
 
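The guard above mirrors the other parallelizable GPT-2 heads: the last transformer block may sit on a different GPU than lm_head and multiple_choice_head, which stay on first_device, so the hidden states are moved back before the head projections run. A standalone sketch of that pattern, with hypothetical modules and a two-device layout:

import torch
import torch.nn as nn

# Hypothetical layout: the body ends on cuda:1 while the head stays on cuda:0.
body = nn.Linear(8, 8).to("cuda:1")
lm_head = nn.Linear(8, 4).to("cuda:0")

x = torch.randn(2, 8, device="cuda:1")
hidden = body(x)
hidden = hidden.to(lm_head.weight.device)  # the same device hop the forward pass above performs
logits = lm_head(hidden)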
@@ -398,7 +398,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
-    all_parallelizable_model_classes = (GPT2LMHeadModel,) if is_torch_available() else ()
+    all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
     test_missing_keys = False
     test_model_parallel = True
 
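With GPT2DoubleHeadsModel added to all_parallelizable_model_classes, the shared model-parallel tests (enabled by test_model_parallel = True) now exercise it alongside GPT2LMHeadModel. A hedged, standalone sanity check in the same spirit, not the mixin's actual test code, using an assumed tiny config:

import torch
from transformers import GPT2Config, GPT2DoubleHeadsModel

if torch.cuda.device_count() >= 2:
    config = GPT2Config(vocab_size=100, n_embd=32, n_layer=4, n_head=2)  # tiny, illustrative config
    model = GPT2DoubleHeadsModel(config)
    model.parallelize({0: [0, 1], 1: [2, 3]})  # blocks 0-1 on GPU 0, blocks 2-3 on GPU 1
    assert model.model_parallel and model.transformer.model_parallel
    model.deparallelize()
    assert not model.model_parallel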