From 61443cd7d917ef323a799ee27bb4abc4344f0d11 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 25 Dec 2020 23:28:12 +0100
Subject: [PATCH] [GPT2] Correct gradient checkpointing (#9308)

* correct gpt2

* fix gpt2

* fix use_cache ordering

* correct past tolerance

* fix for all cases

* style
---
 src/transformers/models/gpt2/modeling_gpt2.py | 23 +++++++++++--------
 tests/test_modeling_common.py                 |  1 +
 tests/test_modeling_tf_gpt2.py                |  2 +-
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index f85be6449e9..bb8046c0e2f 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -184,9 +184,9 @@ class Attention(nn.Module):
         if head_mask is not None:
             w = w * head_mask

-        outputs = [torch.matmul(w, v)]
+        outputs = (torch.matmul(w, v),)
         if output_attentions:
-            outputs.append(w)
+            outputs += (w,)
         return outputs

     def merge_heads(self, x):
@@ -234,7 +234,7 @@ class Attention(nn.Module):
         if use_cache is True:
             present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         else:
-            present = (None,)
+            present = None

         attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
         a = attn_outputs[0]
@@ -243,8 +243,7 @@ class Attention(nn.Module):
         a = self.c_proj(a)
         a = self.resid_dropout(a)

-        outputs = [a, present] + attn_outputs[1:]
-        return outputs  # a, present, (attentions)
+        return (a, present) + attn_outputs[1:]  # a, present, (attentions)


 class MLP(nn.Module):
@@ -321,7 +320,11 @@ class Block(nn.Module):
         # residual connection
         hidden_states = hidden_states + feed_forward_hidden_states

-        outputs = [hidden_states] + outputs
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
         return outputs  # hidden_states, present, (attentions, cross_attentions)


@@ -740,14 +743,14 @@ class GPT2Model(GPT2PreTrainedModel):
                     output_attentions=output_attentions,
                 )

-            hidden_states, present = outputs[:2]
+            hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (present,)
+                presents = presents + (outputs[1],)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2],)
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                 if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3],)
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 57b421c6108..2b720566539 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -233,6 +233,7 @@ class ModelTesterMixin:
             return

         config.gradient_checkpointing = True
+        config.use_cache = False
         config.return_dict = True

         for model_class in self.all_model_classes:
diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py
index d99b76f2f3a..c0781a2947a 100644
--- a/tests/test_modeling_tf_gpt2.py
+++ b/tests/test_modeling_tf_gpt2.py
@@ -247,7 +247,7 @@ class TFGPT2ModelTester:
         output_from_past_slice = output_from_past[:, :, random_slice_idx]

         # test that outputs are equal for slice
-        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

     def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = TFGPT2LMHeadModel(config=config)
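
Note (not part of the patch): gradient checkpointing recomputes block activations during the backward pass, which is why the common test now sets config.use_cache = False and why the Block / GPT2Model output indexing must also work when no "present" tensor is returned. Below is a minimal usage sketch of the configuration this fix targets, assuming a transformers release from this era where gradient_checkpointing is still a GPT2Config flag (newer versions use model.gradient_checkpointing_enable() instead); the model sizes are made up for illustration.

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    # Tiny, hypothetical config; the two flags below are the ones this patch exercises.
    config = GPT2Config(
        n_layer=2,
        n_embd=64,
        n_head=4,
        gradient_checkpointing=True,  # recompute activations instead of storing them
        use_cache=False,              # skip the past key/value cache, as in the fixed test
    )
    model = GPT2LMHeadModel(config)
    model.train()

    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    loss = model(input_ids, labels=input_ids).loss
    loss.backward()  # gradients flow through the checkpointed blocks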