[ProphetNet] Bart-like Refactor (#10501)

* first step to refactor

* make all fast tests pass

* make all slow tests pass

* save intermediate

* correct cache

* finish PR

* make fp16 work
Patrick von Platen 2021-03-04 23:27:12 +03:00 committed by GitHub
parent 6290169eb3
commit c503a1c15e
3 changed files with 245 additions and 187 deletions
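
For orientation, the heart of this refactor is switching ProphetNet's per-layer decoder cache from the old nested-dict format to the flat, BART-style tuple format and moving every tensor to batch-first layout. A minimal sketch of the two cache layouts (tensor shapes are illustrative, not taken from the patch):

import torch

k = v = torch.randn(2, 4, 3, 16)  # (batch, heads, seq_len, head_dim), made-up sizes

# before this PR: one dict per decoder layer, keyed by attention type
old_layer_cache = {
    "self": {"prev_key_states": k, "prev_value_states": v},
    "cross_attention": {"prev_key_states": k, "prev_value_states": v},
}

# after this PR: one flat 4-tuple per decoder layer, as in BART
new_layer_cache = (k, v, k, v)  # (self key, self value, cross key, cross value)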

src/transformers/models/prophetnet/configuration_prophetnet.py

@@ -92,6 +92,8 @@ class ProphetNetConfig(PretrainedConfig):
             smoothing is performed.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
     """
     model_type = "prophetnet"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -119,6 +121,7 @@ class ProphetNetConfig(PretrainedConfig):
         num_buckets=32,
         relative_max_distance=128,
         disable_ngram_loss=False,
+        gradient_checkpointing=False,
         eps=0.0,
         use_cache=True,
         pad_token_id=0,
@@ -161,6 +164,9 @@ class ProphetNetConfig(PretrainedConfig):
         self.use_cache = use_cache

+        # 4 Training Args (should be removed soon)
+        self.gradient_checkpointing = gradient_checkpointing
+
     @property
     def num_attention_heads(self) -> int:
         return self.num_encoder_attention_heads
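
As a reference, a minimal usage sketch of the new flag; only ProphetNetConfig and the gradient_checkpointing argument come from this patch, the surrounding setup is hypothetical:

from transformers import ProphetNetConfig, ProphetNetForConditionalGeneration

config = ProphetNetConfig(gradient_checkpointing=True)  # trade extra compute for less activation memory
model = ProphetNetForConditionalGeneration(config)
model.train()  # checkpointing is only applied while self.training is True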

src/transformers/models/prophetnet/modeling_prophetnet.py

@@ -18,7 +18,7 @@ import copy
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn.functional as F
@@ -567,6 +567,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
     """

     def __init__(self, config: ProphetNetConfig):
+        self.max_length = config.max_position_embeddings
         super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)

     def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
@@ -578,7 +579,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
         if past_key_values is not None:
             # position_ids is the same for every token when decoding a single step
             # Without the int() cast, it doesn't work in some cases when exporting to ONNX
-            prev_num_input_ids = past_key_values[0]["self"]["prev_key_states"].shape[2]
+            prev_num_input_ids = past_key_values[0][0].shape[2]
             num_input_ids = inputs_shape[1] + prev_num_input_ids
             position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
                 int(self.padding_idx + num_input_ids)
@@ -592,6 +593,9 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
                     torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
                 ).long() + self.padding_idx

+            # make sure position_ids are not bigger then max_length
+            position_ids = position_ids.clamp(0, self.max_length - 1)
+
         return super().forward(position_ids), position_ids

     def _forward(self, position_ids):
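
To make the new clamping concrete, a small self-contained sketch of how position ids are derived from the attention mask and bounded; the batch, padding_idx and max_length values are made up:

import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])  # second example padded at the end
padding_idx, max_length = 0, 512  # stand-ins for config.pad_token_id / config.max_position_embeddings

# cumulative sum gives 1-based positions for real tokens, 0 for padding (as in the diff)
position_ids = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() + padding_idx

# the line added by this PR: never index past the embedding table
position_ids = position_ids.clamp(0, max_length - 1)
print(position_ids)  # tensor([[1, 2, 3, 4, 5], [1, 2, 3, 0, 0]])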
@@ -624,66 +628,65 @@ class ProphetNetAttention(nn.Module):

         self.out_proj = nn.Linear(hidden_size, hidden_size)

-    def _reshape(self, tensor, first_dim, batch_size):
-        return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1)
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

     def forward(
         self,
         hidden_states,
         key_value_states: Optional[Tensor] = None,
         attention_mask: Optional[Tensor] = None,
-        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
+        past_key_value: Optional[Tuple[Tensor]] = None,
+        output_attentions: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:

-        sequence_length, batch_size, hidden_size = hidden_states.size()
+        batch_size, tgt_len, hidden_size = hidden_states.size()

         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        cache_key = "cross_attention" if is_cross_attention else "self"
         assert list(hidden_states.size()) == [
-            sequence_length,
             batch_size,
+            tgt_len,
             hidden_size,
-        ], f"Size of hidden states should be {sequence_length, batch_size, hidden_size}, but is {hidden_states.size()}"
+        ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"

         # previous time steps are cached - no need to recompute key and value if they are static
-        if layer_state is not None:
-            saved_state = layer_state.get(cache_key, None)

         query_states = self.query_proj(hidden_states) / (self.head_dim ** 0.5)
-        query_states = self._reshape(query_states, sequence_length, batch_size)

-        if not is_cross_attention:
-            # self-attention
-            key_states = self.key_proj(hidden_states)
-            key_states = self._reshape(key_states, -1, batch_size)
-            value_states = self.value_proj(hidden_states)
-            value_states = self._reshape(value_states, -1, batch_size)
-        elif saved_state is None:
-            # cross-attention without layer state
-            key_states = self.key_proj(key_value_states)
-            key_states = self._reshape(key_states, -1, batch_size)
-            value_states = self.value_proj(key_value_states)
-            value_states = self._reshape(value_states, -1, batch_size)
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
+            value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
         else:
-            key_states = saved_state["prev_key_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            value_states = saved_state["prev_value_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim)
+            # self_attention
+            key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
+            value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)

-        # Update cache
         if is_cross_attention:
-            layer_state[cache_key] = {
-                "prev_key_states": key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-                "prev_value_states": value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-            }
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)

-        key_sequence_length = key_states.size(1)
+        # project states into the correct shape
+        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
         assert attn_weights.size() == (
             batch_size * self.num_attn_heads,
-            sequence_length,
-            key_sequence_length,
-        ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, sequence_length, key_sequence_length}, but is of size {attn_weights.shape}"
+            tgt_len,
+            src_len,
+        ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size {attn_weights.shape}"

         # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
         if attention_mask is not None and attention_mask.dim() == 0:
@@ -691,19 +694,21 @@ class ProphetNetAttention(nn.Module):
         assert attention_mask is None or attention_mask.size() == (
             self.num_attn_heads * batch_size,
             1,
-            key_sequence_length,
-        ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, key_sequence_length}, but is {attention_mask.shape}"
+            src_len,
+        ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
         if attention_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights + attention_mask

-        # need two reshapes to keep gradient at attention weights
-        attn_weights_reshaped = attn_weights.view(
-            batch_size, self.num_attn_heads, sequence_length, key_sequence_length
-        )
-        attn_weights = attn_weights_reshaped.view(
-            batch_size * self.num_attn_heads, sequence_length, key_sequence_length
-        )
+        if output_attentions:
+            # this operation is a bit akward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None

         attn_weights = F.softmax(attn_weights, dim=-1)
         attn_probs = F.dropout(
@@ -715,15 +720,20 @@ class ProphetNetAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)
         assert attn_output.size() == (
             batch_size * self.num_attn_heads,
-            sequence_length,
+            tgt_len,
             self.head_dim,
-        ), "`attn_output` should be of shape {batch_size * self.num_attn_heads, sequence_length, self.head_dim}, but is of shape {attn_output.size()}"
-        attn_output = attn_output.transpose(0, 1).contiguous().view(sequence_length, batch_size, hidden_size)
+        ), "`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.size()}"
+
+        attn_output = (
+            attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
+            .transpose(1, 2)
+            .reshape(batch_size, tgt_len, hidden_size)
+        )
         attn_output = self.out_proj(attn_output)

         attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
-        return attn_output, attn_weights_reshaped
+        return attn_output, attn_weights_reshaped, past_key_value


 class ProphetNetFeedForward(nn.Module):
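
For orientation, a standalone sketch of the batch-first reshaping that the new _shape helper and proj_shape perform above; the sizes are made up:

import torch

batch_size, seq_len, num_heads, head_dim = 2, 7, 4, 16
hidden = torch.randn(batch_size, seq_len, num_heads * head_dim)

# _shape: (B, T, H) -> (B, heads, T, head_dim), which is also the layout cached in past_key_value
shaped = hidden.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

# proj_shape flattens batch and heads so torch.bmm can run one matmul per head
proj_shape = (batch_size * num_heads, -1, head_dim)
flat = shaped.view(*proj_shape)
assert flat.shape == (batch_size * num_heads, seq_len, head_dim)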
@@ -779,8 +789,8 @@ class ProphetNetNgramSelfAttention(nn.Module):
         # for onnx runtime
         self.onnx_trace = False

-    def _reshape(self, tensor, first_dim, batch_size):
-        return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1)
+    def _shape(self, tensor, seq_len, batch_size):
+        return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
@@ -788,23 +798,20 @@ class ProphetNetNgramSelfAttention(nn.Module):
     def forward(
         self,
         hidden_states,
-        layer_state=None,
+        past_key_value: Optional[Tuple[Tensor]] = None,
         attention_mask=None,
         extended_predict_attention_mask=None,
         main_relative_position_buckets=None,
         predict_relative_position_buckets=None,
         position_ids=None,
     ):
-        sequence_length, batch_size, hidden_size = hidden_states.size()
+        batch_size, ngram_sequence_length, hidden_size = hidden_states.size()

         assert list(hidden_states.size()) == [
-            sequence_length,
             batch_size,
+            ngram_sequence_length,
             hidden_size,
-        ], f"`hidden_states` should be of shape {sequence_length, batch_size, hidden_size}, but is of shape {hidden_states.shape}"
-
-        # key and value of previous time steps are cached
-        saved_state = layer_state.get("self", None)
+        ], f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape {hidden_states.shape}"

         # project
         query_states = self.query_proj(hidden_states)
@@ -815,12 +822,18 @@ class ProphetNetNgramSelfAttention(nn.Module):
         query_states = query_states / (self.head_dim ** 0.5)

         # reshape
-        query_states = self._reshape(query_states, sequence_length, batch_size)
-        key_states = self._reshape(key_states, -1, batch_size)
-        value_states = self._reshape(value_states, -1, batch_size)
+        query_states = self._shape(query_states, ngram_sequence_length, batch_size)
+        key_states = self._shape(key_states, -1, batch_size)
+        value_states = self._shape(value_states, -1, batch_size)
+
+        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+
+        query_states = query_states.view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)

         # chunk into main stream and predict stream
-        hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=0)
+        hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
         query_states_list = query_states.chunk(1 + self.ngram, dim=1)
         key_states_list = key_states.chunk(1 + self.ngram, dim=1)
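
For orientation, a toy illustration of the new stream layout: the main stream and the ngram prediction streams are concatenated along the sequence dimension and split back with chunk(1 + ngram, dim=1); the sizes here are made up:

import torch

batch_size, seq_len, hidden_size, ngram = 2, 5, 8, 2

main_stream = torch.randn(batch_size, seq_len, hidden_size)
ngram_streams = [torch.randn(batch_size, seq_len, hidden_size) for _ in range(ngram)]

# the decoder concatenates the streams along dim=1 -> (B, (1 + ngram) * T, C)
hidden_states = torch.cat([main_stream] + ngram_streams, dim=1)

# the ngram self-attention splits them back along the same dimension
chunks = hidden_states.chunk(1 + ngram, dim=1)
assert len(chunks) == 1 + ngram and chunks[0].shape == (batch_size, seq_len, hidden_size)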
@@ -832,24 +845,20 @@ class ProphetNetNgramSelfAttention(nn.Module):
         main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]

         # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
-        if saved_state is not None:
-            prev_main_key_states = saved_state["prev_key_states"].view(
-                batch_size * self.num_attn_heads, -1, self.head_dim
-            )
+        if past_key_value is not None:
+            prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
             main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
-            prev_main_value_states = saved_state["prev_value_states"].view(
-                batch_size * self.num_attn_heads, -1, self.head_dim
-            )
+            prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
             main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)

         # Update cache
-        layer_state["self"] = {
-            "prev_key_states": main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-            "prev_value_states": main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-        }
+        past_key_value = (
+            main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
+            main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
+        )

         # get seq_length of main stream only
-        main_sequence_length = sequence_length // (1 + self.ngram)
+        sequence_length = ngram_sequence_length // (1 + self.ngram)

         # MAIN-STREAM
         # main attn weights
@@ -871,18 +880,21 @@ class ProphetNetNgramSelfAttention(nn.Module):
         ).type_as(main_attn_weights)

         main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
+
         # project to attn_output
         main_attn_output = torch.bmm(main_attn_probs, main_value_states)
+
+        # reshape so that num_heads dim is merged into last `head_dim` axis
         main_attn_output = (
-            main_attn_output.transpose(0, 1).contiguous().view(1, main_sequence_length, batch_size, hidden_size)
+            main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
+            .transpose(1, 2)
+            .reshape(batch_size, 1, sequence_length, hidden_size)
         )
         main_attn_output = self.out_proj(main_attn_output)

         # PREDICT-STREAM
         # [ngram, B*head, T, c]
         predict_query_states = torch.cat(predict_query_states_list, 0).view(
-            self.ngram, -1, main_sequence_length, self.head_dim
+            self.ngram, -1, sequence_length, self.head_dim
         )
         # [ngram, B*head, 2*T, c]
         predict_key_states = torch.cat(
@@ -891,7 +903,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
         # [ngram, T, B, C]
         predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
-            self.ngram, main_sequence_length, batch_size, hidden_size
+            self.ngram, sequence_length, batch_size, hidden_size
         )

         # [ngram, B*head, 2*T, c]
@@ -911,7 +923,9 @@ class ProphetNetNgramSelfAttention(nn.Module):
             predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings

         if extended_predict_attention_mask is not None:
-            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
+            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
+                predict_attn_weights.dtype
+            )

         predict_attn_probs = softmax(
             predict_attn_weights,
@@ -919,35 +933,36 @@ class ProphetNetNgramSelfAttention(nn.Module):
             onnx_trace=self.onnx_trace,
         ).type_as(predict_attn_weights)

         predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training)
+
         # project to attention output
         # [ngram, B*head, T, c]
         predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
-        # [ngram, T, B, C]
+
+        # reshape so that num_heads dim is merged into last `head_dim` axis
+        # [ngram, B, T, C]
         predict_attn_output = (
-            predict_attn_output.transpose(1, 2)
-            .contiguous()
-            .view(self.ngram, main_sequence_length, batch_size, hidden_size)
+            predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
+            .permute(1, 0, 3, 2, 4)
+            .reshape(batch_size, self.ngram, sequence_length, hidden_size)
         )
         predict_attn_output = self.out_proj(predict_attn_output)

         # concat to single attn output
-        # [1+ngram*T, B, C]
-        attn_output = torch.cat([main_attn_output, predict_attn_output], 0).view(-1, batch_size, hidden_size)
+        # [B, 1+ngram*T, C]
+        attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
+
         # reshape into better form for `config.output_attentions`
-        main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, main_sequence_length, -1)
+        main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
         predict_attn_probs = predict_attn_probs.view(
-            self.ngram, batch_size, self.num_attn_heads, main_sequence_length, -1
+            self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
         ).transpose(0, 1)

         attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
-        return attn_output, main_attn_probs, predict_attn_probs
+
+        return attn_output, main_attn_probs, predict_attn_probs, past_key_value

     def get_main_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
     ):
-        # input hidden_states [T,B,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
+        # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
         if main_relative_position_buckets is None:
             batch_size, sequence_length = hidden_states.shape[:2]
@@ -965,7 +980,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )

-        hidden_states = hidden_states.transpose(0, 1)  # [B,T,C]
         rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)  # [B,T,Buckets*head]
         rel_pos_embeddings = rel_pos_embeddings.view(
             rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
@@ -991,7 +1005,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
         self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
     ):
         # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
         sequence_length, batch_size = hidden_states.shape[1:3]

         if predict_relative_position_buckets is None:
@@ -1053,18 +1066,25 @@ class ProphetNetEncoderLayer(nn.Module):
         self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

-    def forward(self, hidden_states, attention_mask):
+    def forward(self, hidden_states, attention_mask, output_attentions: bool = False):
         # 1st residual block
-        attention_output, attn_weights = self.self_attn(
+        attention_output, attn_weights, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
+            output_attentions=output_attentions,
         )
         hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)

         # 2nd residual block
         feed_forward_output = self.feed_forward(hidden_states)
         hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
-        return hidden_states, attn_weights
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs


 class ProphetNetDecoderLayer(nn.Module):
@@ -1090,21 +1110,23 @@ class ProphetNetDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         encoder_hidden_states=None,
         encoder_attn_mask=None,
-        layer_state=None,
-        attention_mask=None,
         extended_predict_attention_mask=None,
         main_relative_position_buckets=None,
         predict_relative_position_buckets=None,
         position_ids=None,
+        past_key_value=None,
+        use_cache: bool = True,
+        output_attentions: bool = False,
     ):
-        layer_state = layer_state if layer_state is not None else {}
         # 1st residual block
-        ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn(
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
             hidden_states=hidden_states,
-            layer_state=layer_state,
+            past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
             extended_predict_attention_mask=extended_predict_attention_mask,
             main_relative_position_buckets=main_relative_position_buckets,
@@ -1113,28 +1135,36 @@ class ProphetNetDecoderLayer(nn.Module):
         )
         hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)

+        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
         cross_attn_weights = None
         if encoder_hidden_states is not None:
             # 2nd residual block
-            attention_output, cross_attn_weights = self.cross_attn(
+            attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attn_mask,
-                layer_state=layer_state,  # mutates layer state
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
             )
             hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)

+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
         # 3rd residual block
         feed_forward_output = self.feed_forward(hidden_states)
         hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

-        return (
-            hidden_states,
-            self_attn_weights,
-            self_attn_weights_ngram,
-            cross_attn_weights,
-            layer_state,
-        )  # just self_attn weights for now, following t5, layer_state = cache for decoding
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs


 @add_start_docstrings(
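
For orientation, a schematic of the per-layer cache tuple the decoder layer above consumes, and how it slices it; names and shapes are illustrative:

import torch

batch_size, num_heads, past_len, src_len, head_dim = 2, 4, 3, 11, 16

# one decoder layer's cache after this PR: a flat 4-tuple of tensors
past_key_value = (
    torch.randn(batch_size, num_heads, past_len, head_dim),  # self-attn key
    torch.randn(batch_size, num_heads, past_len, head_dim),  # self-attn value
    torch.randn(batch_size, num_heads, src_len, head_dim),   # cross-attn key
    torch.randn(batch_size, num_heads, src_len, head_dim),   # cross-attn value
)

# how the decoder layer slices it (mirrors the diff above)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[-2:]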
@@ -1223,21 +1253,37 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
         hidden_states = inputs_embeds + position_embeddings
         hidden_states = self.embeddings_layer_norm(hidden_states)
         hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
-        hidden_states = hidden_states.transpose(0, 1)  # B x T x C -> T x B x C

         encoder_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
         for encoder_layer in self.layers:
             if output_hidden_states:
-                hidden_states = hidden_states.transpose(0, 1)
                 encoder_hidden_states = encoder_hidden_states + (hidden_states,)
-                hidden_states = hidden_states.transpose(0, 1)
-            hidden_states, attn_probs = encoder_layer(hidden_states, attention_mask=extended_attention_mask)
-            if output_attentions:
-                all_attentions = all_attentions + (attn_probs,)

-        hidden_states = hidden_states.transpose(0, 1)
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(encoder_layer),
+                    hidden_states,
+                    extended_attention_mask,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)

         if output_hidden_states:
             encoder_hidden_states = encoder_hidden_states + (hidden_states,)
@@ -1370,26 +1416,24 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         # add position embeddings
         hidden_states = inputs_embeds + main_stream_pos_embed
-        hidden_states = hidden_states.transpose(0, 1)

         ngram_embeddings = self.ngram_embeddings.weight

         # prepare attention mask
         if past_key_values is not None:
             assert (
-                hidden_states.size(0) == 1
+                hidden_states.size(1) == 1
             ), "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"

             ngram_hidden_states = [
-                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1).repeat(1, batch_size, 1)
+                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
                 for ngram in range(self.ngram)
             ]
             extended_attention_mask = None
             extended_predict_attention_mask = None
         else:
             ngram_hidden_states = [
-                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1)
-                for ngram in range(self.ngram)
+                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram)
             ]
             extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
             extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)
@@ -1403,16 +1447,13 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         else:
             extended_encoder_attention_mask = None

-        hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 0)
+        hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)

         if self.embeddings_layer_norm:
             hidden_states = self.embeddings_layer_norm(hidden_states)

         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

-        if encoder_hidden_states is not None:
-            encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
-
         # init attentions, hidden_states and cache with empty tuples
         all_main_stream_hidden_states = () if output_hidden_states else None
         all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None
@@ -1425,47 +1466,75 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         for idx, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 # grad cannot be kept because tensor is sliced
-                all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),)
+                all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
                 if self.config.ngram > 0:
-                    all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),)
+                    all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

-            layer_state = past_key_values[idx] if past_key_values is not None else None
-            (
-                hidden_states,
-                layer_self_attn,
-                layer_self_predict_attn_output,
-                layer_cross_attn,
-                layer_past,
-            ) = decoder_layer(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attn_mask=extended_encoder_attention_mask,
-                layer_state=layer_state,
-                attention_mask=extended_attention_mask,
-                extended_predict_attention_mask=extended_predict_attention_mask,
-                main_relative_position_buckets=main_relative_position_buckets,
-                predict_relative_position_buckets=predict_relative_position_buckets,
-                position_ids=position_ids,
-            )
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)

+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    extended_attention_mask,
+                    encoder_hidden_states,
+                    extended_encoder_attention_mask,
+                    extended_predict_attention_mask,
+                    main_relative_position_buckets,
+                    predict_relative_position_buckets,
+                    position_ids,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attn_mask=extended_encoder_attention_mask,
+                    extended_predict_attention_mask=extended_predict_attention_mask,
+                    main_relative_position_buckets=main_relative_position_buckets,
+                    predict_relative_position_buckets=predict_relative_position_buckets,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]

             if use_cache:
-                present_key_values += (layer_past,)
+                present_key_values += (layer_outputs[4 if output_attentions else 1],)

             if output_attentions:
-                all_main_stream_attns += (layer_self_attn,)
-                all_ngram_stream_attns += (layer_self_predict_attn_output,)
+                all_main_stream_attns += (layer_outputs[1],)
+                all_ngram_stream_attns += (layer_outputs[2],)

                 if self.config.add_cross_attention:
-                    all_cross_attns += (layer_cross_attn,)
+                    all_cross_attns += (layer_outputs[3],)

         if output_hidden_states:
-            all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),)
+            all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
             if self.config.ngram > 0:
-                all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),)
+                all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

         # split last_hidden_state for return
-        last_hidden_state = hidden_states[:sequence_length].transpose(0, 1)
-        last_hidden_state_ngram = hidden_states[sequence_length:].transpose(0, 1) if self.config.ngram > 0 else None
-        encoder_hidden_states = encoder_hidden_states.transpose(0, 1) if encoder_hidden_states is not None else None
+        last_hidden_state = hidden_states[:, :sequence_length]
+        last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None

         if not return_dict:
             return tuple(
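
For orientation, the tuple positions behind `layer_outputs[4 if output_attentions else 1]`, shown with dummy placeholders instead of tensors:

# dummy stand-ins, just to show the positions in the decoder layer's output tuple
def fake_decoder_layer(output_attentions, use_cache):
    outputs = ("hidden_states",)
    if output_attentions:
        outputs += ("self_attn", "ngram_attn", "cross_attn")
    if use_cache:
        outputs += ("present_key_value",)
    return outputs

for output_attentions in (False, True):
    layer_outputs = fake_decoder_layer(output_attentions, use_cache=True)
    assert layer_outputs[4 if output_attentions else 1] == "present_key_value"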
@@ -1516,7 +1585,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         return main_relative_buckets, predict_relative_buckets

     def prepare_attention_mask(self, hidden_states, attention_mask):
-        seq_length, batch_size = hidden_states.shape[:2]
+        batch_size, seq_length = hidden_states.shape[:2]

         # get causal mask
         causal_mask = hidden_states.new(seq_length, seq_length).float().fill_(-float("inf"))
@@ -1534,7 +1603,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)

     def prepare_predict_attention_mask(self, hidden_states, attention_mask):
-        seq_length, batch_size = hidden_states.shape[:2]
+        batch_size, seq_length = hidden_states.shape[:2]

         # get causal mask
         predict_causal_mask = ngram_attention_bias(
@@ -1656,7 +1725,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
                 return_dict=return_dict,
             )

-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
@@ -1856,21 +1925,14 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         return self._shift_right(labels)

     @staticmethod
+    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
-        # this function reorders the cache for beam search
-        def _reorder_cache(cache_dict, beam_idx):
-            for k, key_value_states in cache_dict.items():
-                if key_value_states is not None:
-                    cache_dict[k] = key_value_states.index_select(0, beam_idx)
-            return cache_dict
-
-        reordered_past = []
+        reordered_past = ()
         for layer_past in past:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            layer_past_new = {
-                attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
-            }
-            reordered_past.append(layer_past_new)
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
         return reordered_past

     def get_encoder(self):
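
For orientation, what the new _reorder_cache does during beam search on a toy layer cache: only the self-attention tensors follow the beam permutation, the cross-attention tensors are reused as-is (shapes and beam_idx are made up):

import torch

num_beams, num_heads, past_len, src_len, head_dim = 3, 2, 4, 6, 8
layer_past = (
    torch.randn(num_beams, num_heads, past_len, head_dim),  # self-attn key
    torch.randn(num_beams, num_heads, past_len, head_dim),  # self-attn value
    torch.randn(num_beams, num_heads, src_len, head_dim),   # cross-attn key (beam-invariant)
    torch.randn(num_beams, num_heads, src_len, head_dim),   # cross-attn value (beam-invariant)
)
beam_idx = torch.tensor([2, 0, 0])  # hypothetical beam permutation

reordered = tuple(p.index_select(0, beam_idx) for p in layer_past[:2]) + layer_past[2:]
assert torch.equal(reordered[0][1], layer_past[0][0])  # beam 1 now carries old beam 0's self-attn key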
@@ -1995,7 +2057,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
         outputs = self.prophetnet.decoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -2080,21 +2142,11 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         }

     @staticmethod
+    # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache
     def _reorder_cache(past, beam_idx):
-        # this function reorders the cache for beam search
-        def _reorder_cache(cache_dict, beam_idx):
-            for k, key_value_states in cache_dict.items():
-                if key_value_states is not None:
-                    cache_dict[k] = key_value_states.index_select(0, beam_idx)
-            return cache_dict
-
-        reordered_past = []
+        reordered_past = ()
         for layer_past in past:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            layer_past_new = {
-                attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
-            }
-            reordered_past.append(layer_past_new)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past

tests/test_modeling_prophetnet.py

@@ -243,7 +243,7 @@ class ProphetNetModelTester:
             # There should be `num_layers` key value embeddings stored in decoder_past
             self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
             # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
-            self.parent.assertEqual(len(decoder_past[0]), 2)  # cross-attention + uni-directional self-attention
+            self.parent.assertEqual(len(decoder_past[0]), 4)  # cross-attention + uni-directional self-attention

     def create_and_check_with_lm_head(
         self,