Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
[GPT2] Correct gradient checkpointing (#9308)

* correct gpt2
* fix gpt2
* fix use_cache ordering
* correct past tolerance
* fix for all cases
* style

This commit is contained in:
parent 21fc676645
commit 61443cd7d9
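The gist of the fix: gradient checkpointing re-runs each transformer block through torch.utils.checkpoint during the backward pass, and that wrapper is only set up to hand back tensors or tuples of tensors. The commit therefore switches the GPT-2 sub-modules from list outputs to tuples and makes the cached `present` slot conditional on `use_cache`. A minimal sketch of the constraint, using a stand-in block rather than the real GPT-2 layer:

import torch
from torch.utils.checkpoint import checkpoint

def toy_block(hidden_states):
    # stand-in for a transformer block: returns a tuple of tensors only,
    # with no None placeholder for a key/value cache
    return (hidden_states * 2.0,)

x = torch.randn(2, 4, requires_grad=True)
(out,) = checkpoint(toy_block, x)  # activations are recomputed during backward
out.sum().backward()
assert x.grad is not None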
src/transformers/models/gpt2/modeling_gpt2.py

@@ -184,9 +184,9 @@ class Attention(nn.Module):
         if head_mask is not None:
             w = w * head_mask

-        outputs = [torch.matmul(w, v)]
+        outputs = (torch.matmul(w, v),)
         if output_attentions:
-            outputs.append(w)
+            outputs += (w,)
         return outputs

     def merge_heads(self, x):
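The `_attn` hunk above swaps the list accumulator for tuple concatenation, so the attention outputs can pass through the checkpointing wrapper unchanged. The pattern, with toy tensors in place of the real attention weights and values:

import torch

w = torch.softmax(torch.randn(1, 2, 3, 3), dim=-1)  # toy attention probabilities
v = torch.randn(1, 2, 3, 8)                          # toy value states
output_attentions = True

outputs = (torch.matmul(w, v),)  # start from a tuple instead of a list
if output_attentions:
    outputs += (w,)              # tuple concatenation replaces list.append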
@@ -234,7 +234,7 @@ class Attention(nn.Module):
         if use_cache is True:
             present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         else:
-            present = (None,)
+            present = None

         attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
         a = attn_outputs[0]
@@ -243,8 +243,7 @@ class Attention(nn.Module):
         a = self.c_proj(a)
         a = self.resid_dropout(a)

-        outputs = [a, present] + attn_outputs[1:]
-        return outputs  # a, present, (attentions)
+        return (a, present) + attn_outputs[1:]  # a, present, (attentions)


 class MLP(nn.Module):
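With this hunk, `present` is plain None when caching is off (previously a one-element `(None,)` tuple), and the layer's return value is assembled directly as a tuple. A toy illustration of the two shapes of that return value, using stand-in tensors:

import torch

a = torch.randn(1, 3, 8)               # stand-in projected attention output
attentions = torch.randn(1, 2, 3, 3)   # stand-in attention probabilities

for use_cache in (True, False):
    present = torch.randn(2, 1, 2, 3, 4) if use_cache else None  # stacked key/value or nothing
    result = (a, present) + (attentions,)
    # use_cache=True  -> (a, present, attentions)
    # use_cache=False -> (a, None, attentions); the Block drops the None slot next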
@@ -321,7 +320,11 @@ class Block(nn.Module):
         # residual connection
         hidden_states = hidden_states + feed_forward_hidden_states

-        outputs = [hidden_states] + outputs
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]

         return outputs  # hidden_states, present, (attentions, cross_attentions)

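This is the "use_cache ordering" part of the fix: when `use_cache` is False the block no longer keeps an empty cache slot, so its output tuple contains only tensors and can be checkpointed, at the cost of shifting the positions of the attention tensors. A small sketch of the two layouts with stand-in tensors:

import torch

hidden_states = torch.randn(1, 3, 8)
present = torch.randn(2, 1, 2, 3, 4)   # stand-in stacked key/value cache
attentions = torch.randn(1, 2, 3, 3)

for use_cache in (True, False):
    attn_rest = (present if use_cache else None, attentions)  # what follows the hidden states
    if use_cache:
        block_out = (hidden_states,) + attn_rest      # (hidden_states, present, attentions)
    else:
        block_out = (hidden_states,) + attn_rest[1:]  # (hidden_states, attentions), no None slot
    print(use_cache, len(block_out))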
@@ -740,14 +743,14 @@ class GPT2Model(GPT2PreTrainedModel):
                     output_attentions=output_attentions,
                 )

-            hidden_states, present = outputs[:2]
+            hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (present,)
+                presents = presents + (outputs[1],)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2],)
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                 if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3],)
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
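Because the block's tuple now comes in two layouts, `GPT2Model` indexes the attention tensors relative to whether the cache slot is present. A quick sanity check of that index arithmetic with toy tuples:

import torch

hidden_states = torch.randn(1, 3, 8)
present = torch.randn(2, 1, 2, 3, 4)
self_attn = torch.randn(1, 2, 3, 3)
cross_attn = torch.randn(1, 2, 3, 3)

for use_cache in (True, False):
    outputs = ((hidden_states, present, self_attn, cross_attn) if use_cache
               else (hidden_states, self_attn, cross_attn))
    assert outputs[2 if use_cache else 1] is self_attn   # self-attentions
    assert outputs[3 if use_cache else 2] is cross_attn  # cross-attentions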
tests/test_modeling_common.py

@@ -233,6 +233,7 @@ class ModelTesterMixin:
             return

         config.gradient_checkpointing = True
+        config.use_cache = False
         config.return_dict = True

         for model_class in self.all_model_classes:
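The common test now forces `use_cache = False` alongside `gradient_checkpointing = True`, since cached key/value tensors do not mix with activation recomputation. A hedged usage sketch for the transformers release this commit targets (later releases replace the config flag with `model.gradient_checkpointing_enable()`):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(
    n_layer=2, n_head=2, n_embd=64,  # tiny model just for the sketch
    gradient_checkpointing=True,     # recompute activations in the backward pass
    use_cache=False,                 # caching is disabled while checkpointing
)
model = GPT2LMHeadModel(config)
model.train()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
loss = model(input_ids, labels=input_ids).loss
loss.backward()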
tests/test_modeling_tf_gpt2.py

@@ -247,7 +247,7 @@ class TFGPT2ModelTester:
         output_from_past_slice = output_from_past[:, :, random_slice_idx]

         # test that outputs are equal for slice
-        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

     def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = TFGPT2LMHeadModel(config=config)
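The TF test relaxes the tolerance from 1e-6 to 1e-3 ("correct past tolerance" in the commit message), since the cached and uncached paths accumulate slightly different float32 rounding. A toy illustration of the check itself, with made-up values rather than model logits:

import tensorflow as tf

a = tf.constant([1.0, 2.0, 3.0])
b = a + 5e-4                                 # small numerical drift between the two paths
tf.debugging.assert_near(a, b, rtol=1e-3)    # passes
# tf.debugging.assert_near(a, b, rtol=1e-6)  # would raise InvalidArgumentError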