[GPT2] Correct gradient checkpointing (#9308)

* correct gpt2

* fix gpt2

* fix use_cache ordering

* correct past tolerance

* fix for all cases

* style
Patrick von Platen, 2020-12-25 23:28:12 +01:00, committed by GitHub
parent 21fc676645
commit 61443cd7d9
3 changed files with 15 additions and 11 deletions


@@ -184,9 +184,9 @@ class Attention(nn.Module):
         if head_mask is not None:
             w = w * head_mask

-        outputs = [torch.matmul(w, v)]
+        outputs = (torch.matmul(w, v),)
         if output_attentions:
-            outputs.append(w)
+            outputs += (w,)
         return outputs

     def merge_heads(self, x):
@@ -234,7 +234,7 @@ class Attention(nn.Module):
         if use_cache is True:
             present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         else:
-            present = (None,)
+            present = None

         attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
         a = attn_outputs[0]
@@ -243,8 +243,7 @@ class Attention(nn.Module):
         a = self.c_proj(a)
         a = self.resid_dropout(a)

-        outputs = [a, present] + attn_outputs[1:]
-        return outputs  # a, present, (attentions)
+        return (a, present) + attn_outputs[1:]  # a, present, (attentions)


 class MLP(nn.Module):
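Note on the tuple changes above: with gradient checkpointing enabled, each block's forward pass is replayed through torch.utils.checkpoint.checkpoint, and Python lists or (None,) placeholders do not survive that round trip cleanly, so Attention now builds plain tuples and reports `present = None` when nothing is cached. A minimal sketch of the constraint, with illustrative shapes, not the library code:

import torch
from torch.utils.checkpoint import checkpoint

def attn_like(query, key, value):
    # Mirrors the tuple construction above: scores -> softmax -> weighted values,
    # returned as a tuple of tensors so checkpoint() can rebuild the graph.
    w = torch.matmul(query, key.transpose(-2, -1)).softmax(dim=-1)
    return (torch.matmul(w, value),)

q = torch.randn(2, 4, 8, requires_grad=True)
k = torch.randn(2, 4, 8, requires_grad=True)
v = torch.randn(2, 4, 8, requires_grad=True)

(out,) = checkpoint(attn_like, q, k, v)  # forward is recomputed during backward
out.sum().backward()
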
@@ -321,7 +320,11 @@ class Block(nn.Module):
         # residual connection
         hidden_states = hidden_states + feed_forward_hidden_states

-        outputs = [hidden_states] + outputs
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]

         return outputs  # hidden_states, present, (attentions, cross_attentions)
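The Block change follows from the same constraint: when use_cache is False the attention layer's cache slot is None, and the block now drops it instead of forwarding a non-tensor through the checkpointed call, so downstream code has to index the remaining outputs one position earlier (see the GPT2Model hunk below). A hypothetical wrapper, assuming a simplified block signature and not the exact library code, showing the pattern the outputs must stay compatible with:

import torch

def run_block_with_checkpointing(block, hidden_states, attention_mask):
    # checkpoint() re-runs custom_forward during backward, so everything the
    # block returns here should be a tensor -- hence dropping the None cache
    # entry when use_cache is False.
    def custom_forward(*inputs):
        return block(*inputs)

    return torch.utils.checkpoint.checkpoint(custom_forward, hidden_states, attention_mask)
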
@@ -740,14 +743,14 @@ class GPT2Model(GPT2PreTrainedModel):
                     output_attentions=output_attentions,
                 )

-            hidden_states, present = outputs[:2]
+            hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (present,)
+                presents = presents + (outputs[1],)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2],)
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                 if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3],)
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
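With the cache entry absent when use_cache is False, GPT2Model can no longer unpack outputs[:2] unconditionally; the attention and cross-attention weights shift up by one position, which is what the conditional indices above encode. A hypothetical helper spelling out the same layout, for illustration only:

def split_block_outputs(outputs, use_cache, output_attentions, add_cross_attention):
    # Positional layout handled above:
    #   use_cache=True  -> (hidden_states, present, attentions?, cross_attentions?)
    #   use_cache=False -> (hidden_states, attentions?, cross_attentions?)
    hidden_states = outputs[0]
    present = outputs[1] if use_cache else None
    attentions = outputs[2 if use_cache else 1] if output_attentions else None
    cross_attentions = (
        outputs[3 if use_cache else 2] if output_attentions and add_cross_attention else None
    )
    return hidden_states, present, attentions, cross_attentions
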


@@ -233,6 +233,7 @@ class ModelTesterMixin:
             return

         config.gradient_checkpointing = True
         config.use_cache = False
+        config.return_dict = True

         for model_class in self.all_model_classes:
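The common test now also pins config.return_dict = True next to the existing gradient-checkpointing settings, so every model class is exercised on the same output format during the checkpointed run. Roughly the shape of such a smoke test, with illustrative names rather than the exact test body:

def check_gradient_checkpointing(model_class, config, inputs):
    # Hypothetical sketch: a checkpointed forward/backward pass should run
    # end to end without errors.
    config.gradient_checkpointing = True
    config.use_cache = False
    config.return_dict = True

    model = model_class(config)
    model.train()
    outputs = model(**inputs)
    outputs[0].sum().backward()
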


@@ -247,7 +247,7 @@ class TFGPT2ModelTester:
         output_from_past_slice = output_from_past[:, :, random_slice_idx]

         # test that outputs are equal for slice
-        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

     def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = TFGPT2LMHeadModel(config=config)
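Finally, the TF GPT-2 past/no-past equivalence check loosens rtol from 1e-6 to 1e-3, presumably because the cached and uncached paths accumulate float32 rounding differently and only agree to a few decimal places. A small self-contained illustration of what the looser tolerance accepts, using made-up values rather than model outputs:

import numpy as np
import tensorflow as tf

a = (np.random.rand(2, 1, 8).astype(np.float32) * 0.9) + 0.1
b = a * np.float32(1.0 + 5e-5)  # relative drift well under 1e-3 but above 1e-6

tf.debugging.assert_near(a, b, rtol=1e-3)   # passes
# tf.debugging.assert_near(a, b, rtol=1e-6) # would raise InvalidArgumentError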