GPT2DoubleHeadsModel made parallelizable (#10658)

* GPT2DoubleHeadsModel made parallelizeable

* GPT2DoubleHeadsModel added as parallelizeable onto the GPT2 test suite
This commit is contained in:
Igor Shalyminov 2021-03-15 13:10:44 +00:00 committed by GitHub
parent e12d6f513e
commit 505494a86f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 1 deletions

View File

@ -983,6 +983,28 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.multiple_choice_head = self.multiple_choice_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
def get_output_embeddings(self):
return self.lm_head
@ -1096,6 +1118,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

View File

@ -398,7 +398,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
all_parallelizable_model_classes = (GPT2LMHeadModel,) if is_torch_available() else ()
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
test_missing_keys = False
test_model_parallel = True