diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py index 3e3f4dd84a8..31097d9c01a 100644 --- a/src/transformers/models/prophetnet/configuration_prophetnet.py +++ b/src/transformers/models/prophetnet/configuration_prophetnet.py @@ -92,6 +92,8 @@ class ProphetNetConfig(PretrainedConfig): smoothing is performed. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ model_type = "prophetnet" keys_to_ignore_at_inference = ["past_key_values"] @@ -119,6 +121,7 @@ class ProphetNetConfig(PretrainedConfig): num_buckets=32, relative_max_distance=128, disable_ngram_loss=False, + gradient_checkpointing=False, eps=0.0, use_cache=True, pad_token_id=0, @@ -161,6 +164,9 @@ class ProphetNetConfig(PretrainedConfig): self.use_cache = use_cache + # 4 Training Args (should be removed soon) + self.gradient_checkpointing = gradient_checkpointing + @property def num_attention_heads(self) -> int: return self.num_encoder_attention_heads diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 704e86059c6..03aac1bd899 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -18,7 +18,7 @@ import copy import math import warnings from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch import torch.nn.functional as F @@ -567,6 +567,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding): """ def __init__(self, config: ProphetNetConfig): + self.max_length = config.max_position_embeddings super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id) def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None): @@ -578,7 +579,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding): if past_key_values is not None: # position_ids is the same for every token when decoding a single step # Without the int() cast, it doesn't work in some cases when exporting to ONNX - prev_num_input_ids = past_key_values[0]["self"]["prev_key_states"].shape[2] + prev_num_input_ids = past_key_values[0][0].shape[2] num_input_ids = inputs_shape[1] + prev_num_input_ids position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * ( int(self.padding_idx + num_input_ids) @@ -592,6 +593,9 @@ class ProphetNetPositionalEmbeddings(nn.Embedding): torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask ).long() + self.padding_idx + # make sure position_ids are not bigger than max_length + position_ids = position_ids.clamp(0, self.max_length - 1) + return super().forward(position_ids), position_ids def _forward(self, position_ids): @@ -624,66 +628,65 @@ class ProphetNetAttention(nn.Module): self.out_proj = nn.Linear(hidden_size, hidden_size) - def _reshape(self, tensor, first_dim, batch_size): - return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, 
hidden_states, key_value_states: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, - layer_state: Optional[Dict[str, Optional[Tensor]]] = None, + past_key_value: Optional[Tuple[Tensor]] = None, + output_attentions: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: - sequence_length, batch_size, hidden_size = hidden_states.size() + batch_size, tgt_len, hidden_size = hidden_states.size() # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - cache_key = "cross_attention" if is_cross_attention else "self" assert list(hidden_states.size()) == [ - sequence_length, batch_size, + tgt_len, hidden_size, - ], f"Size of hidden states should be {sequence_length, batch_size, hidden_size}, but is {hidden_states.size()}" + ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}" # previous time steps are cached - no need to recompute key and value if they are static - if layer_state is not None: - saved_state = layer_state.get(cache_key, None) - query_states = self.query_proj(hidden_states) / (self.head_dim ** 0.5) - query_states = self._reshape(query_states, sequence_length, batch_size) - if not is_cross_attention: - # self-attention - key_states = self.key_proj(hidden_states) - key_states = self._reshape(key_states, -1, batch_size) - value_states = self.value_proj(hidden_states) - value_states = self._reshape(value_states, -1, batch_size) - elif saved_state is None: - # cross-attention without layer state - key_states = self.key_proj(key_value_states) - key_states = self._reshape(key_states, -1, batch_size) - value_states = self.value_proj(key_value_states) - value_states = self._reshape(value_states, -1, batch_size) + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.key_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.value_proj(key_value_states), -1, batch_size) else: - key_states = saved_state["prev_key_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim) - value_states = saved_state["prev_value_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim) + # self_attention + key_states = self._shape(self.key_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.value_proj(hidden_states), -1, batch_size) - # Update cache if is_cross_attention: - layer_state[cache_key] = { - "prev_key_states": key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - "prev_value_states": value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - } + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) - key_sequence_length = key_states.size(1) + # project states into the correct shape + proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) assert attn_weights.size() == ( batch_size * self.num_attn_heads, - sequence_length, - key_sequence_length, - ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, sequence_length, key_sequence_length}, but is of size {attn_weights.shape}" + tgt_len, + src_len, + ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size {attn_weights.shape}" # This is part of a workaround to get around fork/join parallelism not supporting Optional types. if attention_mask is not None and attention_mask.dim() == 0: @@ -691,19 +694,21 @@ class ProphetNetAttention(nn.Module): assert attention_mask is None or attention_mask.size() == ( self.num_attn_heads * batch_size, 1, - key_sequence_length, - ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, key_sequence_length}, but is {attention_mask.shape}" + src_len, + ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}" if attention_mask is not None: # don't attend to padding symbols attn_weights = attn_weights + attention_mask - # need two reshapes to keep gradient at attention weights - attn_weights_reshaped = attn_weights.view( - batch_size, self.num_attn_heads, sequence_length, key_sequence_length - ) - attn_weights = attn_weights_reshaped.view( - batch_size * self.num_attn_heads, sequence_length, key_sequence_length - ) + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None attn_weights = F.softmax(attn_weights, dim=-1) attn_probs = F.dropout( @@ -715,15 +720,20 @@ class ProphetNetAttention(nn.Module): attn_output = torch.bmm(attn_probs, value_states) assert attn_output.size() == ( batch_size * self.num_attn_heads, - sequence_length, + tgt_len, self.head_dim, - ), "`attn_output` should be of shape {batch_size * self.num_attn_heads, sequence_length, self.head_dim}, but is of shape {attn_output.size()}" - attn_output = attn_output.transpose(0, 1).contiguous().view(sequence_length, batch_size, hidden_size) + ), "`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.size()}" + + attn_output = ( + attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(batch_size, tgt_len, hidden_size) + ) attn_output = self.out_proj(attn_output) attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) - return attn_output, attn_weights_reshaped + return attn_output, attn_weights_reshaped, past_key_value class ProphetNetFeedForward(nn.Module): @@ -779,8 +789,8 @@ class ProphetNetNgramSelfAttention(nn.Module): # for onnx runtime self.onnx_trace = False - def _reshape(self, tensor, first_dim, batch_size): - return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1) + def _shape(self, tensor, seq_len, batch_size): + return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() def prepare_for_onnx_export_(self): self.onnx_trace = True @@ -788,23 +798,20 @@ class ProphetNetNgramSelfAttention(nn.Module): def forward( self, hidden_states, - layer_state=None, + past_key_value: Optional[Tuple[Tensor]] = None, attention_mask=None, extended_predict_attention_mask=None, main_relative_position_buckets=None, predict_relative_position_buckets=None, position_ids=None, ): - sequence_length, batch_size, hidden_size = hidden_states.size() + batch_size, ngram_sequence_length, hidden_size = hidden_states.size() assert list(hidden_states.size()) == [ - sequence_length, batch_size, + ngram_sequence_length, hidden_size, - ], f"`hidden_states` should be of shape {sequence_length, batch_size, hidden_size}, but is of shape {hidden_states.shape}" - - # key and value of previous time steps are cached - saved_state = layer_state.get("self", None) + ], f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape {hidden_states.shape}" # project query_states = self.query_proj(hidden_states) @@ -815,12 +822,18 @@ class ProphetNetNgramSelfAttention(nn.Module): query_states = query_states / (self.head_dim ** 0.5) # reshape - query_states = self._reshape(query_states, sequence_length, batch_size) - key_states = self._reshape(key_states, -1, batch_size) - value_states = self._reshape(value_states, -1, batch_size) + query_states = self._shape(query_states, ngram_sequence_length, batch_size) + key_states = self._shape(key_states, -1, batch_size) + value_states = self._shape(value_states, -1, batch_size) + + proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + + query_states = query_states.view(*proj_shape) + key_states = 
key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) # chunk into main stream and predict stream - hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=0) + hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) query_states_list = query_states.chunk(1 + self.ngram, dim=1) key_states_list = key_states.chunk(1 + self.ngram, dim=1) @@ -832,24 +845,20 @@ class ProphetNetNgramSelfAttention(nn.Module): main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:] # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) - if saved_state is not None: - prev_main_key_states = saved_state["prev_key_states"].view( - batch_size * self.num_attn_heads, -1, self.head_dim - ) + if past_key_value is not None: + prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim) main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1) - prev_main_value_states = saved_state["prev_value_states"].view( - batch_size * self.num_attn_heads, -1, self.head_dim - ) + prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim) main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1) # Update cache - layer_state["self"] = { - "prev_key_states": main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - "prev_value_states": main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - } + past_key_value = ( + main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), + main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), + ) # get seq_length of main stream only - main_sequence_length = sequence_length // (1 + self.ngram) + sequence_length = ngram_sequence_length // (1 + self.ngram) # MAIN-STREAM # main attn weights @@ -871,18 +880,21 @@ class ProphetNetNgramSelfAttention(nn.Module): ).type_as(main_attn_weights) main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) - # project to attn_output main_attn_output = torch.bmm(main_attn_probs, main_value_states) + + # reshape so that num_heads dim is merged into last `head_dim` axis main_attn_output = ( - main_attn_output.transpose(0, 1).contiguous().view(1, main_sequence_length, batch_size, hidden_size) + main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim) + .transpose(1, 2) + .reshape(batch_size, 1, sequence_length, hidden_size) ) main_attn_output = self.out_proj(main_attn_output) # PREDICT-STREAM # [ngram, B*head, T, c] predict_query_states = torch.cat(predict_query_states_list, 0).view( - self.ngram, -1, main_sequence_length, self.head_dim + self.ngram, -1, sequence_length, self.head_dim ) # [ngram, B*head, 2*T, c] predict_key_states = torch.cat( @@ -891,7 +903,7 @@ class ProphetNetNgramSelfAttention(nn.Module): # [ngram, T, B, C] predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view( - self.ngram, main_sequence_length, batch_size, hidden_size + self.ngram, sequence_length, batch_size, hidden_size ) # [ngram, B*head, 2*T, c] @@ -911,7 +923,9 @@ class ProphetNetNgramSelfAttention(nn.Module): predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings if extended_predict_attention_mask is not None: - predict_attn_weights = predict_attn_weights + extended_predict_attention_mask + predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to( + 
predict_attn_weights.dtype + ) predict_attn_probs = softmax( predict_attn_weights, @@ -919,35 +933,36 @@ class ProphetNetNgramSelfAttention(nn.Module): onnx_trace=self.onnx_trace, ).type_as(predict_attn_weights) predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training) - # project to attention output # [ngram, B*head, T, c] predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states)) - # [ngram, T, B, C] + + # reshape so that num_heads dim is merged into last `head_dim` axis + # [ngram, B, T, C] predict_attn_output = ( - predict_attn_output.transpose(1, 2) - .contiguous() - .view(self.ngram, main_sequence_length, batch_size, hidden_size) + predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim) + .permute(1, 0, 3, 2, 4) + .reshape(batch_size, self.ngram, sequence_length, hidden_size) ) predict_attn_output = self.out_proj(predict_attn_output) # concat to single attn output - # [1+ngram*T, B, C] - attn_output = torch.cat([main_attn_output, predict_attn_output], 0).view(-1, batch_size, hidden_size) - + # [B, 1+ngram*T, C] + attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size) # reshape into better form for `config.output_attentions` - main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, main_sequence_length, -1) + main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1) predict_attn_probs = predict_attn_probs.view( - self.ngram, batch_size, self.num_attn_heads, main_sequence_length, -1 + self.ngram, batch_size, self.num_attn_heads, sequence_length, -1 ).transpose(0, 1) attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) - return attn_output, main_attn_probs, predict_attn_probs + + return attn_output, main_attn_probs, predict_attn_probs, past_key_value def get_main_relative_pos_embeddings( self, hidden_states, attn_weights, position_ids, main_relative_position_buckets ): - # input hidden_states [T,B,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1] + # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1] if main_relative_position_buckets is None: batch_size, sequence_length = hidden_states.shape[:2] @@ -965,7 +980,6 @@ class ProphetNetNgramSelfAttention(nn.Module): self.num_buckets, self.relative_max_distance, relative_positions, False ) - hidden_states = hidden_states.transpose(0, 1) # [B,T,C] rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head] rel_pos_embeddings = rel_pos_embeddings.view( rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads) @@ -991,7 +1005,6 @@ class ProphetNetNgramSelfAttention(nn.Module): self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets ): # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None - sequence_length, batch_size = hidden_states.shape[1:3] if predict_relative_position_buckets is None: @@ -1053,18 +1066,25 @@ class ProphetNetEncoderLayer(nn.Module): self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim) self.feed_forward_layer_norm = LayerNorm(config.hidden_size) - def forward(self, hidden_states, attention_mask): + def forward(self, hidden_states, attention_mask, output_attentions: bool = False): # 1st residual block - 
attention_output, attn_weights = self.self_attn( + attention_output, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, + output_attentions=output_attentions, ) hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) # 2nd residual block feed_forward_output = self.feed_forward(hidden_states) hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) - return hidden_states, attn_weights + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs class ProphetNetDecoderLayer(nn.Module): @@ -1090,21 +1110,23 @@ class ProphetNetDecoderLayer(nn.Module): def forward( self, hidden_states, + attention_mask=None, encoder_hidden_states=None, encoder_attn_mask=None, - layer_state=None, - attention_mask=None, extended_predict_attention_mask=None, main_relative_position_buckets=None, predict_relative_position_buckets=None, position_ids=None, + past_key_value=None, + use_cache: bool = True, + output_attentions: bool = False, ): - layer_state = layer_state if layer_state is not None else {} - # 1st residual block - ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn( + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn( hidden_states=hidden_states, - layer_state=layer_state, + past_key_value=self_attn_past_key_value, attention_mask=attention_mask, extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, @@ -1113,28 +1135,36 @@ class ProphetNetDecoderLayer(nn.Module): ) hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_weights = None if encoder_hidden_states is not None: # 2nd residual block - attention_output, cross_attn_weights = self.cross_attn( + attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, - layer_state=layer_state, # mutates layer state + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + # 3rd residual block feed_forward_output = self.feed_forward(hidden_states) hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) - return ( - hidden_states, - self_attn_weights, - self_attn_weights_ngram, - cross_attn_weights, - layer_state, - ) # just self_attn weights for now, following t5, layer_state = cache for decoding + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs @add_start_docstrings( @@ -1223,21 +1253,37 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): hidden_states = inputs_embeds + position_embeddings hidden_states = self.embeddings_layer_norm(hidden_states) 
hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - hidden_states = hidden_states.transpose(0, 1) # B x T x C -> T x B x C encoder_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for encoder_layer in self.layers: if output_hidden_states: - hidden_states = hidden_states.transpose(0, 1) encoder_hidden_states = encoder_hidden_states + (hidden_states,) - hidden_states = hidden_states.transpose(0, 1) - hidden_states, attn_probs = encoder_layer(hidden_states, attention_mask=extended_attention_mask) - if output_attentions: - all_attentions = all_attentions + (attn_probs,) - hidden_states = hidden_states.transpose(0, 1) + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + extended_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) @@ -1370,26 +1416,24 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): # add position embeddings hidden_states = inputs_embeds + main_stream_pos_embed - hidden_states = hidden_states.transpose(0, 1) ngram_embeddings = self.ngram_embeddings.weight # prepare attention mask if past_key_values is not None: assert ( - hidden_states.size(0) == 1 + hidden_states.size(1) == 1 ), "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" ngram_hidden_states = [ - (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1).repeat(1, batch_size, 1) + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1) for ngram in range(self.ngram) ] extended_attention_mask = None extended_predict_attention_mask = None else: ngram_hidden_states = [ - (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1) - for ngram in range(self.ngram) + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram) ] extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask) extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask) @@ -1403,16 +1447,13 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): else: extended_encoder_attention_mask = None - hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 0) + hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1) if self.embeddings_layer_norm: hidden_states = self.embeddings_layer_norm(hidden_states) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.transpose(0, 1) - # init attentions, hidden_states and cache with empty tuples all_main_stream_hidden_states = () if output_hidden_states else None all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None @@ -1425,47 +1466,75 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: # grad cannot be 
kept because tensor is sliced - all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),) + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) if self.config.ngram > 0: - all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),) + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + extended_attention_mask, + encoder_hidden_states, + extended_encoder_attention_mask, + extended_predict_attention_mask, + main_relative_position_buckets, + predict_relative_position_buckets, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] - layer_state = past_key_values[idx] if past_key_values is not None else None - ( - hidden_states, - layer_self_attn, - layer_self_predict_attn_output, - layer_cross_attn, - layer_past, - ) = decoder_layer( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_attn_mask=extended_encoder_attention_mask, - layer_state=layer_state, - attention_mask=extended_attention_mask, - extended_predict_attention_mask=extended_predict_attention_mask, - main_relative_position_buckets=main_relative_position_buckets, - predict_relative_position_buckets=predict_relative_position_buckets, - position_ids=position_ids, - ) if use_cache: - present_key_values += (layer_past,) + present_key_values += (layer_outputs[4 if output_attentions else 1],) if output_attentions: - all_main_stream_attns += (layer_self_attn,) - all_ngram_stream_attns += (layer_self_predict_attn_output,) + all_main_stream_attns += (layer_outputs[1],) + all_ngram_stream_attns += (layer_outputs[2],) if self.config.add_cross_attention: - all_cross_attns += (layer_cross_attn,) + all_cross_attns += (layer_outputs[3],) if output_hidden_states: - all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),) + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) if self.config.ngram > 0: - all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),) + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) # split last_hidden_state for return - last_hidden_state = hidden_states[:sequence_length].transpose(0, 1) - last_hidden_state_ngram = hidden_states[sequence_length:].transpose(0, 1) if self.config.ngram > 0 else None - encoder_hidden_states = encoder_hidden_states.transpose(0, 1) if 
encoder_hidden_states is not None else None + last_hidden_state = hidden_states[:, :sequence_length] + last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None if not return_dict: return tuple( @@ -1516,7 +1585,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): return main_relative_buckets, predict_relative_buckets def prepare_attention_mask(self, hidden_states, attention_mask): - seq_length, batch_size = hidden_states.shape[:2] + batch_size, seq_length = hidden_states.shape[:2] # get causal mask causal_mask = hidden_states.new(seq_length, seq_length).float().fill_(-float("inf")) @@ -1534,7 +1603,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype) def prepare_predict_attention_mask(self, hidden_states, attention_mask): - seq_length, batch_size = hidden_states.shape[:2] + batch_size, seq_length = hidden_states.shape[:2] # get causal mask predict_causal_mask = ngram_attention_bias( @@ -1656,7 +1725,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel): return_dict=return_dict, ) - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -1856,21 +1925,14 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): return self._shift_right(labels) @staticmethod + # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache def _reorder_cache(past, beam_idx): - # this function reorders the cache for beam search - def _reorder_cache(cache_dict, beam_idx): - for k, key_value_states in cache_dict.items(): - if key_value_states is not None: - cache_dict[k] = key_value_states.index_select(0, beam_idx) - return cache_dict - - reordered_past = [] + reordered_past = () for layer_past in past: - # get the correct batch idx from decoder layer's batch dim for cross and self-attn - layer_past_new = { - attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() - } - reordered_past.append(layer_past_new) + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) return reordered_past def get_encoder(self): @@ -1995,7 +2057,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) outputs = self.prophetnet.decoder( input_ids=input_ids, attention_mask=attention_mask, @@ -2080,21 +2142,11 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): } @staticmethod + # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache def _reorder_cache(past, beam_idx): - # this function reorders the cache for beam search - def _reorder_cache(cache_dict, beam_idx): - for k, key_value_states in cache_dict.items(): - if key_value_states is not None: - cache_dict[k] = key_value_states.index_select(0, beam_idx) - return cache_dict - - reordered_past = [] + reordered_past = () for layer_past in past: - # get the correct batch idx from decoder layer's 
batch dim for cross and self-attn - layer_past_new = { - attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() - } - reordered_past.append(layer_past_new) + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index c9ba56396e1..91ea9f2c2b8 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -243,7 +243,7 @@ class ProphetNetModelTester: # There should be `num_layers` key value embeddings stored in decoder_past self.parent.assertEqual(len(decoder_past), config.num_decoder_layers) # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple - self.parent.assertEqual(len(decoder_past[0]), 2) # cross-attention + uni-directional self-attention + self.parent.assertEqual(len(decoder_past[0]), 4) # cross-attention + uni-directional self-attention def create_and_check_with_lm_head( self,
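
The cache refactor above replaces the old per-layer dict with a flat tuple per decoder layer. The following is a minimal, self-contained sketch of that layout and of the `_reorder_cache` logic now copied from BART; the tensor sizes and the standalone `reorder_cache` helper are illustrative only and are not part of the patch.

import torch

# Illustrative sizes only; the real values come from ProphetNetConfig.
batch_size, num_attn_heads, head_dim = 2, 4, 8
self_len, src_len, num_layers = 1, 7, 3

def make_layer_past(self_len, src_len):
    # One per-layer entry: (self-attn key, self-attn value, cross-attn key, cross-attn value),
    # each of shape (batch_size, num_attn_heads, seq_len, head_dim) -- matching the
    # `len(decoder_past[0]) == 4` assertion in the updated test.
    shape_self = (batch_size, num_attn_heads, self_len, head_dim)
    shape_cross = (batch_size, num_attn_heads, src_len, head_dim)
    return tuple(torch.zeros(s) for s in (shape_self, shape_self, shape_cross, shape_cross))

past_key_values = tuple(make_layer_past(self_len, src_len) for _ in range(num_layers))

def reorder_cache(past, beam_idx):
    # Same idea as `_reorder_cache` in ProphetNetForConditionalGeneration: only the
    # self-attention states (positions 0, 1) are reordered along the batch/beam dim;
    # cached cross-attention states are identical across beams and are kept as-is.
    reordered = ()
    for layer_past in past:
        reordered += (
            tuple(state.index_select(0, beam_idx) for state in layer_past[:2]) + layer_past[2:],
        )
    return reordered

beam_idx = torch.tensor([1, 0])  # e.g. beam search swapped the two hypotheses
reordered = reorder_cache(past_key_values, beam_idx)
assert len(reordered) == num_layers and len(reordered[0]) == 4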