diff --git a/src/transformers/configuration_gpt2.py b/src/transformers/configuration_gpt2.py
index 87588811d96..6142a907376 100644
--- a/src/transformers/configuration_gpt2.py
+++ b/src/transformers/configuration_gpt2.py
@@ -103,6 +103,8 @@ class GPT2Config(PretrainedConfig):
             :class:`~transformers.GPT2DoubleHeadsModel` and :class:`~transformers.TFGPT2DoubleHeadsModel`.
             The dropout ratio to be used after the projection and activation.
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.

     Example::

@@ -142,6 +144,7 @@ class GPT2Config(PretrainedConfig):
         summary_first_dropout=0.1,
         bos_token_id=50256,
         eos_token_id=50256,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -164,6 +167,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
+        self.gradient_checkpointing = gradient_checkpointing
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id

diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index f1671c69cb2..3ab1660ae4f 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch OpenAI GPT-2 model."""

-
 import os
 import warnings
 from dataclasses import dataclass
@@ -624,16 +623,35 @@ class GPT2Model(GPT2PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

-            outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # checkpointing only works with tuple returns, not with lists
+                        return tuple(output for output in module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    layer_past,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )

             hidden_states, present = outputs[:2]
             if use_cache is True:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ea27b1604c8..07ec629d790 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -679,8 +679,10 @@ class Trainer:
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=True,
+                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
             )
+            # find_unused_parameters breaks checkpointing as per
+            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
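Aside, not part of the patch: the modeling hunk above relies on torch.utils.checkpoint.checkpoint, which reruns a function during the backward pass instead of storing its intermediate activations. Because checkpoint only forwards positional arguments and requires tuple (not list) returns, the patch closes over use_cache and output_attentions in create_custom_forward rather than passing them as keywords. A minimal sketch of the trade-off, with a hypothetical toy block standing in for a GPT-2 layer:

    import torch
    import torch.utils.checkpoint

    # Hypothetical stand-in for a transformer block, for illustration only.
    block = torch.nn.Sequential(
        torch.nn.Linear(16, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 16),
    )

    x = torch.randn(2, 16, requires_grad=True)

    # Plain forward: intermediate activations stay alive until backward.
    y_plain = block(x)

    # Checkpointed forward: activations inside `block` are freed after the
    # forward pass and recomputed during backward, trading an extra forward
    # pass through the block for lower peak memory.
    y_ckpt = torch.utils.checkpoint.checkpoint(block, x)

    # Both paths give the same values and gradients.
    assert torch.allclose(y_plain, y_ckpt)
    y_ckpt.sum().backward()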
diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py
index dcb0faefe4e..6f550cd75a2 100644
--- a/tests/test_modeling_gpt2.py
+++ b/tests/test_modeling_gpt2.py
@@ -88,7 +88,7 @@ class GPT2ModelTester:
         self.bos_token_id = vocab_size - 1
         self.eos_token_id = vocab_size - 1

-    def prepare_config_and_inputs(self):
+    def prepare_config_and_inputs(self, gradient_checkpointing=False):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

         input_mask = None
@@ -127,6 +127,7 @@ class GPT2ModelTester:
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             return_dict=True,
+            gradient_checkpointing=gradient_checkpointing,
         )

         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -269,6 +270,15 @@ class GPT2ModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

+    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+        model = GPT2LMHeadModel(config)
+        model.to(torch_device)
+
+        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def create_and_check_double_lm_head_model(
         self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
     ):
@@ -355,6 +365,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)

+    def test_gpt2_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -366,33 +380,34 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
 class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gpt2(self):
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
-        model.to(torch_device)
-        input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
-        expected_output_ids = [
-            464,
-            3290,
-            373,
-            1043,
-            287,
-            257,
-            2214,
-            1474,
-            262,
-            16246,
-            286,
-            2688,
-            290,
-            2688,
-            27262,
-            13,
-            198,
-            198,
-            464,
-            3290,
-        ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        for checkpointing in [True, False]:
+            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
+            model.to(torch_device)
+            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
+            expected_output_ids = [
+                464,
+                3290,
+                373,
+                1043,
+                287,
+                257,
+                2214,
+                1474,
+                262,
+                16246,
+                286,
+                2688,
+                290,
+                2688,
+                27262,
+                13,
+                198,
+                198,
+                464,
+                3290,
+            ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
+            output_ids = model.generate(input_ids, do_sample=False)
+            self.assertListEqual(output_ids[0].tolist(), expected_output_ids)

     @slow
     def test_lm_generate_distilgpt2(self):
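Usage sketch (not part of the patch; assumes a transformers build that includes this change): setting gradient_checkpointing on the config makes GPT2Model route each block through torch.utils.checkpoint during the forward pass. Unused from_pretrained kwargs are forwarded to the config, which is what the updated generation test above exploits.

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    # Enable the new flag via the config ...
    config = GPT2Config.from_pretrained("gpt2", gradient_checkpointing=True)
    model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
    # ... or equivalently, pass it straight to from_pretrained:
    # model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=True)
    model.train()

    input_ids = torch.tensor([[464, 3290]])  # "The dog"
    outputs = model(input_ids, labels=input_ids)
    loss = outputs[0]  # first tuple element is the LM loss when labels are given
    loss.backward()  # each block's activations are recomputed here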