Adding gradient checkpointing to GPT2 (#7446)

* GPT2 gradient checkpointing

* find_unused_parameters removed if checkpointing

* find_unused_parameters removed if checkpointing

* Update src/transformers/configuration_gpt2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Added a test for generation with checkpointing

* Update src/transformers/configuration_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Authored by Teven on 2020-09-29 18:26:26 +02:00; committed by GitHub
parent 52e8392b7e
commit 9e9a1fb8c7
4 changed files with 79 additions and 40 deletions

src/transformers/configuration_gpt2.py

@@ -103,6 +103,8 @@ class GPT2Config(PretrainedConfig):
             :class:`~transformers.GPT2DoubleHeadsModel` and :class:`~transformers.TFGPT2DoubleHeadsModel`.
             The dropout ratio to be used after the projection and activation.
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
@@ -142,6 +144,7 @@ class GPT2Config(PretrainedConfig):
         summary_first_dropout=0.1,
         bos_token_id=50256,
         eos_token_id=50256,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -164,6 +167,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
+        self.gradient_checkpointing = gradient_checkpointing
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
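
For reference, the flag is read from the model config at forward time, so it can be switched on either by building a GPT2Config directly or by passing it to from_pretrained, which routes unrecognized keyword arguments to the config (the updated slow generation test below relies on this). A minimal usage sketch, assuming the transformers version of this commit:

    from transformers import GPT2Config, GPT2LMHeadModel

    # Fresh model with checkpointing enabled through the config.
    config = GPT2Config(gradient_checkpointing=True)
    model = GPT2LMHeadModel(config)

    # Or enable it on a pretrained checkpoint; the kwarg is forwarded to the config.
    model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=True)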

src/transformers/modeling_gpt2.py

@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch OpenAI GPT-2 model."""
 
-
 import os
 import warnings
 from dataclasses import dataclass
@@ -624,16 +623,35 @@ class GPT2Model(GPT2PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
 
-            outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # checkpointing only works with tuple returns, not with lists
+                        return tuple(output for output in module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    layer_past,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
 
             hidden_states, present = outputs[:2]
             if use_cache is True:
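
Two constraints of torch.utils.checkpoint.checkpoint at the torch versions targeted here explain the wrapper above: the checkpointed callable only receives positional arguments, so use_cache and output_attentions are captured by closure rather than passed as keywords, and its return value must be a tensor or a tuple of tensors, hence the tuple(...) conversion flagged in the comment. The mechanism itself is illustrated by the self-contained sketch below (names and sizes are arbitrary): activations inside the checkpointed segment are not kept after the forward pass and are recomputed when backward reaches it, trading compute for memory.

    import torch
    import torch.utils.checkpoint

    layer = torch.nn.Linear(16, 16)

    def block(x):
        # Intermediate activations here are not stored; they are recomputed
        # when .backward() reaches this checkpointed segment.
        return torch.relu(layer(x))

    x = torch.randn(4, 16, requires_grad=True)
    out = torch.utils.checkpoint.checkpoint(block, x)
    out.sum().backward()
    assert x.grad is not None and layer.weight.grad is not None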

src/transformers/trainer.py

@@ -679,8 +679,10 @@ class Trainer:
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=True,
+                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
             )
+            # find_unused_parameters breaks checkpointing as per
+            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
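
Why the flag has to be turned off: roughly, with find_unused_parameters=True DDP inspects the autograd graph after each forward pass and marks parameters it considers unused as ready for reduction; when checkpointing later re-runs the forward inside the backward pass, the same parameters can get marked again and DDP errors out, which is the failure mode discussed in the linked PR #4659 comment. A standalone sketch of the same guard outside of Trainer (wrap_ddp is a hypothetical helper and assumes torch.distributed.init_process_group has already been called):

    from torch.nn.parallel import DistributedDataParallel

    def wrap_ddp(model, local_rank):
        # `model` is assumed to be a transformers PreTrainedModel, so it carries
        # a `config`. Only search for unused parameters when checkpointing is off;
        # otherwise the recomputed backward pass trips DDP's bookkeeping.
        find_unused = not getattr(model.config, "gradient_checkpointing", False)
        return DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=find_unused,
        )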

tests/test_modeling_gpt2.py

@@ -88,7 +88,7 @@ class GPT2ModelTester:
         self.bos_token_id = vocab_size - 1
         self.eos_token_id = vocab_size - 1
 
-    def prepare_config_and_inputs(self):
+    def prepare_config_and_inputs(self, gradient_checkpointing=False):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
         input_mask = None
@@ -127,6 +127,7 @@ class GPT2ModelTester:
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             return_dict=True,
+            gradient_checkpointing=gradient_checkpointing,
         )
 
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -269,6 +270,15 @@ class GPT2ModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+        model = GPT2LMHeadModel(config)
+        model.to(torch_device)
+
+        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def create_and_check_double_lm_head_model(
         self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
     ):
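
The new create_and_check_forward_and_backwards helper exercises a full forward and backward pass with checkpointing enabled. The memory/speed trade-off described in the config docstring can be eyeballed with a rough sketch like the one below (peak_memory is a hypothetical helper; it assumes a CUDA device and the transformers version of this commit, and absolute numbers depend on hardware and batch shape):

    import torch
    import torch.utils.checkpoint  # the checkpointed path references this submodule
    from transformers import GPT2Config, GPT2LMHeadModel

    def peak_memory(gradient_checkpointing):
        # One forward/backward on a random batch; returns peak allocated bytes.
        torch.cuda.reset_peak_memory_stats()
        config = GPT2Config(gradient_checkpointing=gradient_checkpointing)
        model = GPT2LMHeadModel(config).cuda()
        input_ids = torch.randint(0, config.vocab_size, (2, 512), device="cuda")
        loss = model(input_ids, labels=input_ids)[0]
        loss.backward()
        return torch.cuda.max_memory_allocated()

    # The checkpointed run is expected to peak noticeably lower, at the price
    # of recomputing block activations during backward.
    print(peak_memory(True), peak_memory(False))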
@@ -355,6 +365,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
 
+    def test_gpt2_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -366,33 +380,34 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
 class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gpt2(self):
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
-        model.to(torch_device)
-        input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
-        expected_output_ids = [
-            464,
-            3290,
-            373,
-            1043,
-            287,
-            257,
-            2214,
-            1474,
-            262,
-            16246,
-            286,
-            2688,
-            290,
-            2688,
-            27262,
-            13,
-            198,
-            198,
-            464,
-            3290,
-        ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        for checkpointing in [True, False]:
+            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
+            model.to(torch_device)
+            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
+            expected_output_ids = [
+                464,
+                3290,
+                373,
+                1043,
+                287,
+                257,
+                2214,
+                1474,
+                262,
+                16246,
+                286,
+                2688,
+                290,
+                2688,
+                27262,
+                13,
+                198,
+                198,
+                464,
+                3290,
+            ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
+            output_ids = model.generate(input_ids, do_sample=False)
+            self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
 
     @slow
     def test_lm_generate_distilgpt2(self):
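
The loop added above checks that greedy generation is identical whether or not checkpointing is enabled, i.e. the feature changes memory behaviour, not model outputs. The same invariance can be spot-checked more cheaply on raw logits with a small randomly initialized model (a sketch under the same version assumptions; eval mode keeps dropout off so both passes are deterministic):

    import torch
    import torch.utils.checkpoint  # ensure the submodule is loaded for the checkpointed path
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(n_layer=2, n_head=2, n_embd=64)  # tiny random model, illustration only
    model = GPT2LMHeadModel(config).eval()  # eval() disables dropout
    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    with torch.no_grad():
        baseline = model(input_ids)[0]

    # Flip the flag on the same weights: the forward math is unchanged, only
    # how activations are handled during a training backward pass differs.
    model.config.gradient_checkpointing = True
    with torch.no_grad():
        checkpointed = model(input_ids)[0]

    assert torch.allclose(baseline, checkpointed)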