Mirror of https://github.com/huggingface/transformers.git
[GPT2] Correct gradient checkpointing (#9308)
* correct gpt2
* fix gpt2
* fix use_cache ordering
* correct past tolerance
* fix for all cases
* style
This commit is contained in: parent 21fc676645, commit 61443cd7d9
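For context, this change makes gradient checkpointing usable with GPT-2 by returning tuples instead of lists from the attention and block forwards, and by keeping the output ordering consistent when caching is off. A minimal usage sketch, not part of the diff, assuming the config attributes this commit touches (exact attribute names may differ across transformers versions):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config()
config.gradient_checkpointing = True  # recompute block activations during backward
config.use_cache = False              # past key/value caching is incompatible with checkpointing

model = GPT2LMHeadModel(config)
model.train()
input_ids = torch.randint(0, config.vocab_size, (2, 16))
loss = model(input_ids, labels=input_ids).loss
loss.backward()                       # activations are recomputed block by block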
@@ -184,9 +184,9 @@ class Attention(nn.Module):
         if head_mask is not None:
             w = w * head_mask
 
-        outputs = [torch.matmul(w, v)]
+        outputs = (torch.matmul(w, v),)
         if output_attentions:
-            outputs.append(w)
+            outputs += (w,)
         return outputs
 
     def merge_heads(self, x):
@@ -234,7 +234,7 @@ class Attention(nn.Module):
         if use_cache is True:
             present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         else:
-            present = (None,)
+            present = None
 
         attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
         a = attn_outputs[0]
@@ -243,8 +243,7 @@ class Attention(nn.Module):
         a = self.c_proj(a)
         a = self.resid_dropout(a)
 
-        outputs = [a, present] + attn_outputs[1:]
-        return outputs  # a, present, (attentions)
+        return (a, present) + attn_outputs[1:]  # a, present, (attentions)
 
 
 class MLP(nn.Module):
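Note: the list-to-tuple changes above are what make the blocks checkpointable. torch.utils.checkpoint routes the wrapped forward through a torch.autograd.Function, which tracks tensor or tuple-of-tensor outputs; a plain Python list is not handled as a differentiable output. A standalone sketch of the constraint, using a toy function rather than the transformers code:

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # a tuple of tensors is checkpoint-friendly; a Python list here would not
    # be tracked as a differentiable output by the autograd machinery
    return (x * 2, x + 1)

x = torch.randn(4, 4, requires_grad=True)
out, _ = checkpoint(block, x)
out.sum().backward()  # gradients flow through the recomputed block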
@@ -321,7 +320,11 @@ class Block(nn.Module):
         # residual connection
         hidden_states = hidden_states + feed_forward_hidden_states
 
-        outputs = [hidden_states] + outputs
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
         return outputs  # hidden_states, present, (attentions, cross_attentions)
 
 
@@ -740,14 +743,14 @@ class GPT2Model(GPT2PreTrainedModel):
                 output_attentions=output_attentions,
             )
 
-            hidden_states, present = outputs[:2]
+            hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (present,)
+                presents = presents + (outputs[1],)
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2],)
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                 if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3],)
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
 
             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
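Note: since the Block now drops the `present` slot when `use_cache` is off, GPT2Model shifts its indices with the `2 if use_cache else 1` / `3 if use_cache else 2` expressions above. A toy illustration of the two layouts (the helper name is hypothetical, not the transformers code):

# Block outputs after this change:
#   use_cache=True  -> (hidden_states, present, attentions?, cross_attentions?)
#   use_cache=False -> (hidden_states, attentions?, cross_attentions?)
def unpack_block_outputs(outputs, use_cache, output_attentions):
    hidden_states = outputs[0]
    present = outputs[1] if use_cache else None
    attentions = outputs[2 if use_cache else 1] if output_attentions else None
    return hidden_states, present, attentions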
@@ -233,6 +233,7 @@ class ModelTesterMixin:
             return
 
         config.gradient_checkpointing = True
+        config.use_cache = False
         config.return_dict = True
 
         for model_class in self.all_model_classes:
@@ -247,7 +247,7 @@ class TFGPT2ModelTester:
         output_from_past_slice = output_from_past[:, :, random_slice_idx]
 
         # test that outputs are equal for slice
-        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 
     def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = TFGPT2LMHeadModel(config=config)
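Note: the relaxed rtol accounts for small float32 differences between the with-past and without-past code paths, which can exceed 1e-6 while staying far below 1e-3. A tiny standalone illustration with made-up values, not taken from the test:

import tensorflow as tf

logits_no_past = tf.constant([1.00000, 2.00000])
logits_with_past = tf.constant([1.00004, 1.99993])  # small float32 drift

# passes with the relaxed tolerance; rtol=1e-6 would raise InvalidArgumentError
tf.debugging.assert_near(logits_with_past, logits_no_past, rtol=1e-3)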