Adding gradient checkpointing to GPT2 (#7446)
* GPT2 gradient checkpointing
* find_unused_parameters removed if checkpointing
* find_unused_parameters removed if checkpointing
* Update src/transformers/configuration_gpt2.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Added a test for generation with checkpointing
* Update src/transformers/configuration_gpt2.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 52e8392b7e
commit 9e9a1fb8c7
src/transformers/configuration_gpt2.py

@@ -103,6 +103,8 @@ class GPT2Config(PretrainedConfig):
             :class:`~transformers.GPT2DoubleHeadsModel` and :class:`~transformers.TFGPT2DoubleHeadsModel`.
 
             The dropout ratio to be used after the projection and activation.
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
 
@@ -142,6 +144,7 @@ class GPT2Config(PretrainedConfig):
         summary_first_dropout=0.1,
         bos_token_id=50256,
         eos_token_id=50256,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -164,6 +167,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
+        self.gradient_checkpointing = gradient_checkpointing
 
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
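For context, the new option is a plain attribute on the config, so it can be set at construction time and read back defensively with getattr, which is how the model and Trainer changes below consume it. A minimal sketch (the flag name comes from this diff; the rest is standard transformers usage)::

    from transformers import GPT2Config

    # gradient_checkpointing defaults to False and is stored verbatim on the config.
    config = GPT2Config(gradient_checkpointing=True)

    # Downstream code reads it defensively, so configs serialized before this
    # change simply fall back to False.
    print(getattr(config, "gradient_checkpointing", False))  # True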
src/transformers/modeling_gpt2.py

@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch OpenAI GPT-2 model."""
 
-
 import os
 import warnings
 from dataclasses import dataclass
@@ -624,6 +623,25 @@ class GPT2Model(GPT2PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
 
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # checkpointing only works with tuple returns, not with lists
+                        return tuple(output for output in module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    layer_past,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
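The create_custom_forward wrapper exists because torch.utils.checkpoint.checkpoint only replays positional arguments through the re-run forward pass and, in its reentrant form, expects a tensor or a tuple of tensors back rather than a list; the closure carries the non-tensor flags (use_cache, output_attentions) across the checkpoint boundary. A self-contained toy sketch of the same pattern, with a stand-in Block module rather than the real GPT-2 block::

    import torch
    import torch.utils.checkpoint


    class Block(torch.nn.Module):
        """Stand-in for a transformer block that returns a tuple, like GPT2's Block."""

        def __init__(self, dim=16):
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, hidden_states, attention_mask, output_attentions):
            out = torch.relu(self.linear(hidden_states))
            # Tuple (not list) return, as required by torch.utils.checkpoint.
            return (out,)


    def create_custom_forward(module, output_attentions):
        # Close over the non-tensor flag so only tensors/None pass through checkpoint.
        def custom_forward(*inputs):
            return tuple(output for output in module(*inputs, output_attentions))

        return custom_forward


    block = Block()
    hidden_states = torch.randn(2, 4, 16, requires_grad=True)

    # Activations inside `block` are recomputed during backward instead of stored.
    outputs = torch.utils.checkpoint.checkpoint(
        create_custom_forward(block, output_attentions=False),
        hidden_states,
        None,  # attention_mask, mirroring how GPT2 passes optional args positionally
    )
    outputs[0].sum().backward()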
src/transformers/trainer.py

@@ -679,8 +679,10 @@ class Trainer:
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=True,
+                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
             )
+            # find_unused_parameters breaks checkpointing as per
+            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
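The Trainer tweak is needed because find_unused_parameters=True makes DDP walk the autograd graph after each forward pass, while gradient checkpointing rebuilds parts of that graph during backward; combining the two breaks gradient synchronization, as the linked comment explains. A rough sketch of the resulting wrapping logic as a standalone helper (the helper name is illustrative, and an initialized process group plus a model with a .config attribute are assumed)::

    import torch
    from torch.nn.parallel import DistributedDataParallel


    def wrap_model_for_ddp(model, local_rank):
        # Illustrative helper, not part of the PR. Assumes
        # torch.distributed.init_process_group() has already been called.
        find_unused = not getattr(model.config, "gradient_checkpointing", False)
        return DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # find_unused_parameters breaks checkpointing, see
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
            find_unused_parameters=find_unused,
        )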
tests/test_modeling_gpt2.py

@@ -88,7 +88,7 @@ class GPT2ModelTester:
         self.bos_token_id = vocab_size - 1
         self.eos_token_id = vocab_size - 1
 
-    def prepare_config_and_inputs(self):
+    def prepare_config_and_inputs(self, gradient_checkpointing=False):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
         input_mask = None
@@ -127,6 +127,7 @@ class GPT2ModelTester:
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             return_dict=True,
+            gradient_checkpointing=gradient_checkpointing,
         )
 
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -269,6 +270,15 @@ class GPT2ModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+        model = GPT2LMHeadModel(config)
+        model.to(torch_device)
+
+        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def create_and_check_double_lm_head_model(
         self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
     ):
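Outside the test harness, the same forward/backward round trip can be exercised directly; a small sketch with a deliberately tiny config (the sizes are arbitrary and only keep the example fast, mirroring what the new tester method checks)::

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    # Tiny, arbitrary dimensions so the example runs quickly on CPU.
    config = GPT2Config(
        vocab_size=100,
        n_positions=64,
        n_ctx=64,
        n_embd=32,
        n_layer=2,
        n_head=4,
        gradient_checkpointing=True,
    )
    model = GPT2LMHeadModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    result = model(input_ids, labels=input_ids, return_dict=True)

    # Gradients still flow through the checkpointed blocks.
    result.loss.backward()
    print(result.loss.shape, result.logits.shape)  # torch.Size([]) torch.Size([2, 8, 100])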
@@ -355,6 +365,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
 
+    def test_gpt2_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -366,7 +380,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
 class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gpt2(self):
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
+        for checkpointing in [True, False]:
+            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
             model.to(torch_device)
             input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
             expected_output_ids = [
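The updated slow test relies on from_pretrained forwarding the config override, and on the fact that checkpointing changes the memory/compute trade-off, not the math, so greedy generation should produce the same token ids either way. A sketch of the same idea outside the test (the pretrained weights are downloaded from the Hub; the input ids 464, 3290 for "The dog" are taken from the test above, the generation settings are illustrative)::

    import torch
    from transformers import GPT2LMHeadModel

    outputs = {}
    for checkpointing in [True, False]:
        # Config overrides pass through from_pretrained, as in the updated test.
        model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
        model.eval()

        input_ids = torch.tensor([[464, 3290]], dtype=torch.long)  # "The dog"
        outputs[checkpointing] = model.generate(input_ids, do_sample=False, max_length=20)

    # Greedy outputs should match with and without checkpointing.
    assert torch.equal(outputs[True], outputs[False])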