mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Adding gradient checkpointing to GPT2 (#7446)
* GPT2 gradient checkpointing * find_unused_parameters removed if checkpointing * find_unused_parameters removed if checkpointing * Update src/transformers/configuration_gpt2.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Added a test for generation with checkpointing * Update src/transformers/configuration_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
52e8392b7e
commit
9e9a1fb8c7
@ -103,6 +103,8 @@ class GPT2Config(PretrainedConfig):
|
||||
:class:`~transformers.GPT2DoubleHeadsModel` and :class:`~transformers.TFGPT2DoubleHeadsModel`.
|
||||
|
||||
The dropout ratio to be used after the projection and activation.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
@ -142,6 +144,7 @@ class GPT2Config(PretrainedConfig):
|
||||
summary_first_dropout=0.1,
|
||||
bos_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
@ -164,6 +167,7 @@ class GPT2Config(PretrainedConfig):
|
||||
self.summary_activation = summary_activation
|
||||
self.summary_first_dropout = summary_first_dropout
|
||||
self.summary_proj_to_labels = summary_proj_to_labels
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch OpenAI GPT-2 model."""
|
||||
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
@ -624,16 +623,35 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# checkpointing only works with tuple returns, not with lists
|
||||
return tuple(output for output in module(*inputs, use_cache, output_attentions))
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states, present = outputs[:2]
|
||||
if use_cache is True:
|
||||
|
@ -679,8 +679,10 @@ class Trainer:
|
||||
model,
|
||||
device_ids=[self.args.local_rank],
|
||||
output_device=self.args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
|
||||
)
|
||||
# find_unused_parameters breaks checkpointing as per
|
||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_text("args", self.args.to_json_string())
|
||||
|
@ -88,7 +88,7 @@ class GPT2ModelTester:
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
def prepare_config_and_inputs(self, gradient_checkpointing=False):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
@ -127,6 +127,7 @@ class GPT2ModelTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
return_dict=True,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
@ -269,6 +270,15 @@ class GPT2ModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPT2LMHeadModel(config)
|
||||
model.to(torch_device)
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
result.loss.backward()
|
||||
|
||||
def create_and_check_double_lm_head_model(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
|
||||
):
|
||||
@ -355,6 +365,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_gpt2_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
@ -366,33 +380,34 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_gpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
for checkpointing in [True, False]:
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
|
Loading…
Reference in New Issue
Block a user