[GPTNeo] Fix gradient checkpointing bug (#21733)

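* fix bug: with gradient checkpointing enabled during training, `use_cache` was forced to `False` only after `presents` had already been initialized from it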

* forward contrib credits from discussions

* change logic: move the `use_cache` / gradient checkpointing check ahead of the `presents` initialization

---------

Co-authored-by: edbeeching <edbeeching@users.noreply.github.com>
This commit is contained in:
Younes Belkada 2023-02-23 09:48:19 +01:00 committed by GitHub
parent 36a6a1adb6
commit 78a93d17c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -587,6 +587,13 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
@@ -595,11 +602,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):