[GPT2] Correct gradient checkpointing (#9308)

* correct gpt2

* fix gpt2

* fix use_cache ordering

* correct past tolerance

* fix for all cases

* style
Author: Patrick von Platen
Date: 2020-12-25 23:28:12 +01:00 (committed by GitHub)
Parent: 21fc676645
Commit: 61443cd7d9
3 changed files with 15 additions and 11 deletions
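For context on why the changes below are needed: torch.utils.checkpoint re-runs a block's forward pass during the backward pass, so the checkpointed function should return tensors or tuples of tensors, and key/value caching has to stay off while training with checkpointing. A minimal sketch of that pattern, not taken from the repository (`run_block_checkpointed`, `block`, and the argument names are placeholders):

import torch
from torch.utils.checkpoint import checkpoint

def run_block_checkpointed(block, hidden_states, attention_mask):
    # Placeholder wrapper: `block` stands for a GPT-2 Block-like nn.Module.
    def custom_forward(*inputs):
        # use_cache must be off: the forward pass is re-executed during
        # backward, so returning cached key/value states makes no sense here.
        return block(*inputs, use_cache=False, output_attentions=False)

    # checkpoint() expects the wrapped function to return a tensor or a
    # tuple of tensors, which is one reason the model builds tuples
    # instead of lists for its intermediate outputs.
    return checkpoint(custom_forward, hidden_states, attention_mask)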

src/transformers/models/gpt2/modeling_gpt2.py

@@ -184,9 +184,9 @@ class Attention(nn.Module):
         if head_mask is not None:
             w = w * head_mask
-        outputs = [torch.matmul(w, v)]
+        outputs = (torch.matmul(w, v),)
         if output_attentions:
-            outputs.append(w)
+            outputs += (w,)
         return outputs
 
     def merge_heads(self, x):
@@ -234,7 +234,7 @@ class Attention(nn.Module):
         if use_cache is True:
             present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         else:
-            present = (None,)
+            present = None
         attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
         a = attn_outputs[0]
@@ -243,8 +243,7 @@ class Attention(nn.Module):
         a = self.c_proj(a)
         a = self.resid_dropout(a)
-        outputs = [a, present] + attn_outputs[1:]
-        return outputs  # a, present, (attentions)
+        return (a, present) + attn_outputs[1:]  # a, present, (attentions)
 
 class MLP(nn.Module):
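An illustrative sketch (not code from the file) of the tuple-building pattern that _attn now uses; returning a tuple keeps the attention output consistent with how the rest of the model, and torch.utils.checkpoint, consume it. `attn_outputs_like` and the tensor shapes are made up for the demonstration:

import torch

def attn_outputs_like(w, v, output_attentions=False):
    outputs = (torch.matmul(w, v),)  # always a 1-tuple holding the context layer
    if output_attentions:
        outputs += (w,)              # optionally append the attention weights
    return outputs                   # (a,) or (a, attention_weights)

w = torch.softmax(torch.randn(2, 4, 4), dim=-1)  # dummy attention weights
v = torch.randn(2, 4, 8)                         # dummy value states
out = attn_outputs_like(w, v, output_attentions=True)
print(len(out), out[0].shape, out[1].shape)      # 2 torch.Size([2, 4, 8]) torch.Size([2, 4, 4])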
@@ -321,7 +320,11 @@ class Block(nn.Module):
         # residual connection
         hidden_states = hidden_states + feed_forward_hidden_states
-        outputs = [hidden_states] + outputs
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
 
         return outputs  # hidden_states, present, (attentions, cross_attentions)
@@ -740,14 +743,14 @@ class GPT2Model(GPT2PreTrainedModel):
                     output_attentions=output_attentions,
                 )
-            hidden_states, present = outputs[:2]
+            hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (present,)
+                presents = presents + (outputs[1],)
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2],)
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                 if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3],)
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
 
             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
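To make the index arithmetic above concrete, here is the Block output layout the loop now has to handle (a hedged illustration with stand-in objects, not code from the file):

hidden_states, present, attn = object(), object(), object()  # stand-ins for tensors

outputs_cached = (hidden_states, present, attn)  # use_cache=True keeps `present`
outputs_no_cache = (hidden_states, attn)         # use_cache=False drops it

for use_cache, outputs in ((True, outputs_cached), (False, outputs_no_cache)):
    # the attention entry shifts down by one slot when no `present` is returned
    assert outputs[2 if use_cache else 1] is attn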

tests/test_modeling_common.py

@@ -233,6 +233,7 @@ class ModelTesterMixin:
             return
 
         config.gradient_checkpointing = True
+        config.use_cache = False
         config.return_dict = True
 
         for model_class in self.all_model_classes:
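A hedged end-to-end sketch of the setup this test guard enables, using a deliberately tiny GPT-2 configuration (the sizes are arbitrary, and it assumes the transformers API of this release, where gradient checkpointing is toggled on the config):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=100)
config.gradient_checkpointing = True
config.use_cache = False   # caching and checkpointing do not mix: forward is re-run
config.return_dict = True

model = GPT2LMHeadModel(config)
model.train()
input_ids = torch.randint(0, 100, (2, 8))
loss = model(input_ids, labels=input_ids).loss
loss.backward()            # activations inside each Block are recomputed here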

tests/test_modeling_tf_gpt2.py

@@ -247,7 +247,7 @@ class TFGPT2ModelTester:
         output_from_past_slice = output_from_past[:, :, random_slice_idx]
 
         # test that outputs are equal for slice
-        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 
     def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = TFGPT2LMHeadModel(config=config)
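For reference, a minimal illustration (not from the test file, values invented) of what relaxing rtol from 1e-6 to 1e-3 tolerates in the past/no-past comparison:

import tensorflow as tf

no_past = tf.constant([1.0000, 2.0000])
with_past = tf.constant([1.0004, 1.9996])               # small numerical drift from the cached path
tf.debugging.assert_near(no_past, with_past, rtol=1e-3)  # passes; rtol=1e-6 would raise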