Adding gradient checkpointing to GPT2 (#7446)

* GPT2 gradient checkpointing

* find_unused_parameters removed if checkpointing

* find_unused_parameters removed if checkpointing

* Update src/transformers/configuration_gpt2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Added a test for generation with checkpointing

* Update src/transformers/configuration_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Teven 2020-09-29 18:26:26 +02:00 committed by GitHub
parent 52e8392b7e
commit 9e9a1fb8c7
4 changed files with 79 additions and 40 deletions

src/transformers/configuration_gpt2.py

@@ -103,6 +103,8 @@ class GPT2Config(PretrainedConfig):
            :class:`~transformers.GPT2DoubleHeadsModel` and :class:`~transformers.TFGPT2DoubleHeadsModel`.
            The dropout ratio to be used after the projection and activation.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.

    Example::
@@ -142,6 +144,7 @@ class GPT2Config(PretrainedConfig):
        summary_first_dropout=0.1,
        bos_token_id=50256,
        eos_token_id=50256,
        gradient_checkpointing=False,
        **kwargs
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -164,6 +167,7 @@ class GPT2Config(PretrainedConfig):
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        self.gradient_checkpointing = gradient_checkpointing
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
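With the constructor change above, the option can be switched on either when building a config or when loading pretrained weights. A minimal sketch (the from_pretrained override follows the generation test further down):

from transformers import GPT2Config, GPT2LMHeadModel

# Build a config with the new flag ...
config = GPT2Config(gradient_checkpointing=True)
model = GPT2LMHeadModel(config)

# ... or override it directly in from_pretrained, as the generation test does.
model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=True)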

src/transformers/modeling_gpt2.py

@@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""
import os
import warnings
from dataclasses import dataclass
@@ -624,6 +623,25 @@ class GPT2Model(GPT2PreTrainedModel):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # checkpointing only works with tuple returns, not with lists
                        return tuple(output for output in module(*inputs, use_cache, output_attentions))

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    layer_past,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
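The wrapper above exists because torch.utils.checkpoint.checkpoint only re-runs positional tensor arguments and expects a tensor or tuple of tensors back, while a GPT-2 block also takes plain Python flags and returns a list. A self-contained sketch of the same pattern on a toy module (ToyBlock and its shapes are invented for illustration):

import torch
from torch.utils.checkpoint import checkpoint


class ToyBlock(torch.nn.Module):
    # Stand-in for a transformer block: takes a boolean flag and returns a list.
    def __init__(self, dim=16):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, output_attentions=False):
        out = torch.relu(self.linear(hidden_states))
        extras = (hidden_states.mean(dim=-1),) if output_attentions else ()
        return [out, *extras]


def create_custom_forward(module, output_attentions):
    # Close over the non-tensor flag and convert the list return into the
    # tuple that checkpoint() requires.
    def custom_forward(*inputs):
        return tuple(module(*inputs, output_attentions=output_attentions))

    return custom_forward


block = ToyBlock()
hidden_states = torch.randn(2, 8, 16, requires_grad=True)

# The forward pass does not store the block's intermediate activations; they
# are recomputed during backward(), trading compute for memory.
outputs = checkpoint(create_custom_forward(block, output_attentions=True), hidden_states)
outputs[0].sum().backward()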

src/transformers/trainer.py

@@ -679,8 +679,10 @@ class Trainer:
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
            )
            # find_unused_parameters breaks checkpointing as per
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
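Per the linked comment, DDP's unused-parameter detection conflicts with re-running the forward pass during backward, so the Trainer now disables it whenever the model config enables checkpointing. A minimal sketch of that logic as a standalone helper (wrap_ddp is a hypothetical name; it assumes torch.distributed is already initialised):

import torch


def wrap_ddp(model, local_rank):
    # Mirror of the Trainer change: find_unused_parameters breaks gradient
    # checkpointing, so only enable it when checkpointing is off.
    return torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
    )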

tests/test_modeling_gpt2.py

@@ -88,7 +88,7 @@ class GPT2ModelTester:
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1

    def prepare_config_and_inputs(self):
    def prepare_config_and_inputs(self, gradient_checkpointing=False):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
@@ -127,6 +127,7 @@ class GPT2ModelTester:
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            return_dict=True,
            gradient_checkpointing=gradient_checkpointing,
        )

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -269,6 +270,15 @@ class GPT2ModelTester:
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPT2LMHeadModel(config)
        model.to(torch_device)

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def create_and_check_double_lm_head_model(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
    ):
@@ -355,6 +365,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)

    def test_gpt2_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -366,7 +380,8 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
    @slow
    def test_lm_generate_gpt2(self):
        model = GPT2LMHeadModel.from_pretrained("gpt2")
        for checkpointing in [True, False]:
            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
            model.to(torch_device)
            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
            expected_output_ids = [
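The generation test above loads the model twice, with and without checkpointing, and expects the same greedy continuation of "The dog" either way. A minimal sketch of the same check outside the test harness (the tokenizer-based prompt encoding is illustrative; the test itself hard-codes the token ids):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer.encode("The dog", return_tensors="pt")

for checkpointing in [True, False]:
    model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
    model.eval()
    with torch.no_grad():
        output_ids = model.generate(input_ids, do_sample=False)  # greedy decoding
    print(checkpointing, tokenizer.decode(output_ids[0]))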