Mirror of https://github.com/huggingface/transformers.git
Refactor OPT model (#36101)
* remove cross attention
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* remove is_decoder
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix pkv
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
parent 924f1c717a
commit 0baf003915
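Review note: the net effect on the attention forward pass is that keys and values are always projected from `hidden_states`, and `past_key_value`, when present, is simply concatenated onto them before being returned as the new cache; the cross-attention branches (`key_value_states`) and the `is_decoder` flag are removed. The snippet below is a minimal, self-contained sketch of that remaining key/value path, not the actual modeling code: the `SimplifiedSelfAttentionKV` class name and the toy dimensions are invented for illustration.

```python
# Minimal sketch of the post-refactor key/value path (illustrative only).
from typing import Optional, Tuple

import torch
from torch import nn


class SimplifiedSelfAttentionKV(nn.Module):
    def __init__(self, embed_dim: int = 16, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bsz, _, _ = hidden_states.size()
        # Keys/values always come from hidden_states: no cross-attention branch,
        # no `is_decoder` guard.
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        # The cache, when present, is concatenated along the sequence dimension,
        # and the concatenated tensors become the new cache.
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        return key_states, value_states


# Toy decoding step: the second call extends the cache built from the prompt.
attn = SimplifiedSelfAttentionKV()
k, v = attn(torch.randn(1, 5, 16))                        # cache covers 5 positions
k, v = attn(torch.randn(1, 1, 16), past_key_value=(k, v))
print(k.shape)                                            # torch.Size([1, 4, 6, 4])
```

Dropping the branches is safe for OPT because the model is decoder-only, so the `key_value_states` cross-attention path was never exercised by the OPT decoder itself.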
@@ -98,7 +98,6 @@ class OPTAttention(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        is_decoder: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -117,7 +116,6 @@ class OPTAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
         self.scaling = self.head_dim**-0.5
-        self.is_decoder = is_decoder

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
@@ -130,7 +128,6 @@ class OPTAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
@@ -139,44 +136,19 @@ class OPTAttention(nn.Module):
         position_ids: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
         bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
         # get key, value proj
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
+        past_key_value = (key_states, value_states)

         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
@@ -268,7 +240,6 @@ class OptFlashAttention2(OPTAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
@@ -276,44 +247,19 @@ class OptFlashAttention2(OPTAttention):
         position_ids: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
         bsz, _, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states)
         # get key, value proj
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
+        past_key_value = (key_states, value_states)

         query_length = query_states.shape[1]
         tgt_len = key_states.shape[-2]
@@ -380,7 +326,6 @@ class OPTSdpaAttention(OPTAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
@@ -399,9 +344,7 @@ class OPTSdpaAttention(OPTAttention):
                 layer_head_mask=layer_head_mask,
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
-                key_value_states=key_value_states,
             )  # TODO after merge add position_ids=position_ids
-        is_cross_attention = key_value_states is not None

         bsz, q_len, _ = hidden_states.size()

@@ -409,34 +352,14 @@ class OPTSdpaAttention(OPTAttention):
         query_states = self._shape(query_states, -1, bsz)

         # get key, value proj
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
+        past_key_value = (key_states, value_states)

         # shape now is (bsz, num_heads, seq_len, head_dim), all are continuous

@@ -480,7 +403,7 @@ class OPTDecoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size

-        self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True)
+        self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config)

         self.do_layer_norm_before = config.do_layer_norm_before
         self.dropout = config.dropout
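Call-site note on the hunk above: the attention classes are now built from the config alone, as `OPTDecoderLayer` does here. A hedged sketch of a direct instantiation under this commit (assumes a default `OPTConfig`; not code from the PR):

```python
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTAttention

config = OPTConfig()                # default config, illustrative only
attn = OPTAttention(config=config)  # `is_decoder` no longer needs to be passed
```

Because `forward` also dropped `key_value_states`, passing encoder states to these attention classes now raises a `TypeError`; within OPT that path was unused anyway.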