mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
[cleanup] factor out get_head_mask, invert_attn_mask, get_exten… (#3806)
* Delete some copy pasted code
This commit is contained in:
parent d22894dfd4
commit dbd041243d
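Before the per-file hunks, here is the gist of the duplicated logic this commit centralizes. The sketch below is a standalone illustration in plain PyTorch (the function names are ours, not the library's): it reproduces the additive attention-mask trick and the 5D head-mask expansion that every model below used to inline.

    import torch

    def extend_attention_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        # [batch_size, seq_length] -> [batch_size, 1, 1, seq_length] so it broadcasts over heads
        extended = attention_mask[:, None, None, :].to(dtype)
        # 1.0 -> keep (adds 0.0 to the attention scores); 0.0 -> mask (adds -10000.0)
        return (1.0 - extended) * -10000.0

    def head_mask_to_5d(head_mask: torch.Tensor, num_hidden_layers: int, dtype=torch.float32) -> torch.Tensor:
        if head_mask.dim() == 1:      # [num_heads] -> one mask shared by every layer
            head_mask = head_mask[None, None, :, None, None].expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:    # [num_hidden_layers, num_heads] -> one mask per layer
            head_mask = head_mask[:, None, :, None, None]
        return head_mask.to(dtype)

    print(extend_attention_mask(torch.tensor([[1, 1, 1, 0]])))          # zeros where attended, -10000.0 at the pad
    print(head_mask_to_5d(torch.ones(12), num_hidden_layers=6).shape)   # torch.Size([6, 1, 12, 1, 1])

The new ModuleUtilsMixin helpers shown further down implement essentially this (plus the decoder causal mask), so each model's forward shrinks to one call per mask.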
@@ -552,19 +552,7 @@ class AlbertModel(AlbertPreTrainedModel):
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_hidden_layers
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         embedding_output = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
@@ -703,36 +703,9 @@ class BertModel(BertPreTrainedModel):

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        if attention_mask.dim() == 3:
-            extended_attention_mask = attention_mask[:, None, :, :]
-        elif attention_mask.dim() == 2:
-            # Provided a padding mask of dimensions [batch_size, seq_length]
-            # - if the model is a decoder, apply a causal mask in addition to the padding mask
-            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            if self.config.is_decoder:
-                batch_size, seq_length = input_shape
-                seq_ids = torch.arange(seq_length, device=device)
-                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-                causal_mask = causal_mask.to(
-                    attention_mask.dtype
-                )  # causal and attention masks must have same type with pytorch version < 1.3
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-            else:
-                extended_attention_mask = attention_mask[:, None, None, :]
-        else:
-            raise ValueError(
-                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
-                    input_shape, attention_mask.shape
-                )
-            )
-
-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -10000.0 for masked positions.
-        # Since we are adding it to the raw scores before the softmax, this is
-        # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+            attention_mask, input_shape, self.device
+        )

         # If a 2D ou 3D attention mask is provided for the cross-attention
         # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
@@ -741,22 +714,7 @@ class BertModel(BertPreTrainedModel):
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-
-            if encoder_attention_mask.dim() == 3:
-                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
-            elif encoder_attention_mask.dim() == 2:
-                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
-            else:
-                raise ValueError(
-                    "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
-                        encoder_hidden_shape, encoder_attention_mask.shape
-                    )
-                )
-
-            encoder_extended_attention_mask = encoder_extended_attention_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # fp16 compatibility
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None

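The invert_attention_mask call that replaces the block above lives in ModuleUtilsMixin (see the modeling_utils hunk further down). A standalone sketch of what it computes, with an illustrative function name of our own:

    import torch

    def invert_attention_mask_sketch(encoder_attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        # broadcast a 2D padding mask to [batch, 1, 1, seq] (a 3D mask to [batch, 1, from, to])
        if encoder_attention_mask.dim() == 3:
            mask = encoder_attention_mask[:, None, :, :]
        else:
            mask = encoder_attention_mask[:, None, None, :]
        # flip to additive form: 0 where attended, a large negative number where masked
        return (1.0 - mask.to(dtype)) * -1e9

    print(invert_attention_mask_sketch(torch.tensor([[1, 1, 0]])).squeeze())  # 0, 0, -1e9

Worth noting: the shared helper uses -1e9 (as the old T5 code did), while the BERT-specific block it replaces used -10000.0; a small numerical difference to keep in mind when comparing outputs.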
@@ -765,19 +723,7 @@ class BertModel(BertPreTrainedModel):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_hidden_layers
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         embedding_output = self.embeddings(
             input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
@@ -392,26 +392,11 @@ class CTRLModel(CTRLPreTrainedModel):
             # positions we want to attend and -10000.0 for masked positions.
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0

         # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.n_layer
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
@@ -460,23 +460,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
             attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

         # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_hidden_layers
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         if inputs_embeds is None:
             inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
@@ -164,65 +164,6 @@ class ElectraPreTrainedModel(BertPreTrainedModel):
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"

-    def get_extended_attention_mask(self, attention_mask, input_shape, device):
-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just need to make it broadcastable to all heads.
-        if attention_mask.dim() == 3:
-            extended_attention_mask = attention_mask[:, None, :, :]
-        elif attention_mask.dim() == 2:
-            # Provided a padding mask of dimensions [batch_size, seq_length]
-            # - if the model is a decoder, apply a causal mask in addition to the padding mask
-            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            if self.config.is_decoder:
-                batch_size, seq_length = input_shape
-                seq_ids = torch.arange(seq_length, device=device)
-                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-                causal_mask = causal_mask.to(
-                    attention_mask.dtype
-                )  # causal and attention masks must have same type with pytorch version < 1.3
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-            else:
-                extended_attention_mask = attention_mask[:, None, None, :]
-        else:
-            raise ValueError(
-                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
-                    input_shape, attention_mask.shape
-                )
-            )
-
-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -10000.0 for masked positions.
-        # Since we are adding it to the raw scores before the softmax, this is
-        # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        return extended_attention_mask
-
-    def get_head_mask(self, head_mask):
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        num_hidden_layers = self.config.num_hidden_layers
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * num_hidden_layers
-
-        return head_mask
-

 ELECTRA_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@@ -376,7 +317,7 @@ class ElectraModel(ElectraPreTrainedModel):
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
-        head_mask = self.get_head_mask(head_mask)
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         hidden_states = self.embeddings(
             input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
@@ -201,23 +201,7 @@ class FlaubertModel(XLMModel):
         # langs = langs.transpose(0, 1)

         # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.n_layers
+        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

         # do not recompute cached elements
         if cache is not None and input_ids is not None:
@@ -471,19 +471,7 @@ class GPT2Model(GPT2PreTrainedModel):
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # head_mask has shape n_layer x batch x n_heads x N x N
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.n_layer
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
@@ -23,6 +23,7 @@ import torch.nn as nn
 from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import add_start_docstrings
+from .modeling_utils import ModuleUtilsMixin


 logger = logging.getLogger(__name__)
@@ -148,7 +149,7 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs:
     MMBT_START_DOCSTRING,
     MMBT_INPUTS_DOCSTRING,
 )
-class MMBTModel(nn.Module):
+class MMBTModel(ModuleUtilsMixin):
     r"""
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
@@ -237,7 +238,6 @@ class MMBTModel(nn.Module):
             attention_mask = torch.cat(
                 [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1
             )
-
         if encoder_attention_mask is None:
             encoder_attention_mask = torch.ones(input_shape, device=device)
         else:
@@ -245,61 +245,9 @@ class MMBTModel(nn.Module):
                 [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
             )

-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just need to make it broadcastable to all heads.
-        if attention_mask.dim() == 3:
-            extended_attention_mask = attention_mask[:, None, :, :]
-
-        # Provided a padding mask of dimensions [batch_size, seq_length]
-        # - if the model is a decoder, apply a causal mask in addition to the padding mask
-        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if attention_mask.dim() == 2:
-            if self.config.is_decoder:
-                batch_size, seq_length = input_shape
-                seq_ids = torch.arange(seq_length, device=device)
-                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-            else:
-                extended_attention_mask = attention_mask[:, None, None, :]
-
-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -10000.0 for masked positions.
-        # Since we are adding it to the raw scores before the softmax, this is
-        # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        # If a 2D ou 3D attention mask is provided for the cross-attention
-        # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
-        if encoder_attention_mask.dim() == 3:
-            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
-        if encoder_attention_mask.dim() == 2:
-            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
-
-        encoder_extended_attention_mask = encoder_extended_attention_mask.to(
-            dtype=next(self.parameters()).dtype
-        )  # fp16 compatibility
-        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_hidden_layers
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
+        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         encoder_outputs = self.transformer.encoder(
             embedding_output,
@@ -425,22 +425,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
             attention_mask = (1.0 - attention_mask) * -10000.0

         # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.n_layer
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if inputs_embeds is None:
             inputs_embeds = self.tokens_embed(input_ids)
@@ -184,7 +184,7 @@ class T5LayerFF(nn.Module):


 class T5Attention(nn.Module):
-    def __init__(self, config, has_relative_attention_bias=False):
+    def __init__(self, config: T5Config, has_relative_attention_bias=False):
         super().__init__()
         self.is_decoder = config.is_decoder
         self.has_relative_attention_bias = has_relative_attention_bias
@@ -693,73 +693,15 @@ class T5Stack(T5PreTrainedModel):

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        if attention_mask.dim() == 3:
-            extended_attention_mask = attention_mask[:, None, :, :]
-        elif attention_mask.dim() == 2:
-            # Provided a padding mask of dimensions [batch_size, mask_seq_length]
-            # - if the model is a decoder, apply a causal mask in addition to the padding mask
-            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
-            if self.config.is_decoder:
-                seq_ids = torch.arange(mask_seq_length, device=inputs_embeds.device)
-                causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None]
-                causal_mask = causal_mask.to(attention_mask)
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-                if past_key_value_states[0] is not None:
-                    extended_attention_mask = extended_attention_mask[:, :, -1:, :]
-            else:
-                extended_attention_mask = attention_mask[:, None, None, :]
-
-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -1e9 for masked positions.
-        # Since we are adding it to the raw scores before the softmax, this is
-        # effectively the same as removing these entirely.
-
-        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
-        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
-        # extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
-
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)

         if self.is_decoder and encoder_attention_mask is not None:
-            # If a 2D ou 3D attention mask is provided for the cross-attention
-            # we need to make broadcastabe to [batch_size, num_heads, mask_seq_length, mask_seq_length]
-            if encoder_attention_mask.dim() == 3:
-                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
-            if encoder_attention_mask.dim() == 2:
-                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
-
-            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
-            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
-            # encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2))
-
-            encoder_extended_attention_mask = encoder_extended_attention_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # fp16 compatibility
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None

         # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x mask_seq_length x mask_seq_length]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.num_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_layers
+        head_mask = self.get_head_mask(head_mask, self.config.num_layers)

         present_key_value_states = ()
         all_hidden_states = ()
         all_attentions = ()
@@ -17,10 +17,10 @@

 import logging
 import os
-import typing
+from typing import Callable, Tuple

 import torch
-from torch import nn
+from torch import Tensor, device, dtype, nn
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F

@@ -109,9 +109,102 @@ class ModuleUtilsMixin:
             module.mem_rss_pre_forward = 0

     @property
-    def device(self):
+    def device(self) -> device:
         return next(self.parameters()).device

+    @property
+    def dtype(self) -> dtype:
+        return next(self.parameters()).dtype
+
+    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
+        """type: torch.Tensor -> torch.Tensor"""
+        if encoder_attention_mask.dim() == 3:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+        if encoder_attention_mask.dim() == 2:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+        # /transformer/transformer_layers.py#L270
+        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+        # encoder_extended_attention_mask.transpose(-1, -2))
+        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+        return encoder_extended_attention_mask
+
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
+        """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
+
+        Arguments:
+            attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
+            input_shape: tuple, shape of input_ids
+            device: torch.Device, usually self.device
+
+        Returns:
+            torch.Tensor with dtype of attention_mask.dtype
+        """
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if self.config.is_decoder:
+                batch_size, seq_length = input_shape
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+
+    def get_head_mask(self, head_mask, num_hidden_layers):
+        """
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        attention_probs has shape bsz x n_heads x N x N
+        Arguments:
+            head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
+            num_hidden_layers: int
+        Returns:
+            Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+            or list with [None] for each layer
+        """
+        if head_mask is not None:
+            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+        else:
+            head_mask = [None] * num_hidden_layers
+
+        return head_mask
+
+    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+        if head_mask.dim() == 1:
+            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+        elif head_mask.dim() == 2:
+            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+        head_mask = head_mask.to(dtype=self.dtype)  # switch to fload if need + fp16 compatibility
+        return head_mask
+

 class PreTrainedModel(nn.Module, ModuleUtilsMixin):
     r""" Base class for all models.
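With the methods above on ModuleUtilsMixin, any module that also carries a config and at least one parameter can reuse them. A hedged usage sketch against the transformers API of this commit (ToyConfig and ToyEncoder are illustrative names, not library classes):

    import torch
    from torch import nn
    from transformers.modeling_utils import ModuleUtilsMixin

    class ToyConfig:
        is_decoder = False        # get_extended_attention_mask branches on this
        num_hidden_layers = 4

    class ToyEncoder(nn.Module, ModuleUtilsMixin):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.proj = nn.Linear(8, 8)  # at least one parameter, so self.device / self.dtype work

        def forward(self, hidden_states, attention_mask, head_mask=None):
            input_shape = hidden_states.size()[:-1]
            # additive mask of shape [batch, 1, 1, seq]: 0.0 to attend, -10000.0 on padding
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
            # [None] * num_hidden_layers when head_mask is None, else a 5D tensor
            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
            # a real model would now pass both to its attention layers
            return extended_attention_mask, head_mask

    model = ToyEncoder(ToyConfig())
    ext_mask, head_mask = model(torch.zeros(2, 5, 8), attention_mask=torch.ones(2, 5))
    print(ext_mask.shape, len(head_mask))  # torch.Size([2, 1, 1, 5]) 4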
@@ -340,7 +433,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

-        if hasattr(self.config, "xla_device") and self.config.xla_device:
+        if getattr(self.config, "xla_device", False):
             import torch_xla.core.xla_model as xm

             if xm.is_master_ordinal():
@@ -588,13 +681,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ""
         model_to_load = model
-        if not hasattr(model, cls.base_model_prefix) and any(
-            s.startswith(cls.base_model_prefix) for s in state_dict.keys()
-        ):
+        has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
+        if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
             start_prefix = cls.base_model_prefix + "."
-        if hasattr(model, cls.base_model_prefix) and not any(
-            s.startswith(cls.base_model_prefix) for s in state_dict.keys()
-        ):
+        if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
             model_to_load = getattr(model, cls.base_model_prefix)

         load(model_to_load, prefix=start_prefix)
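The has_prefix_module refactor above is plain dict-key logic; a minimal illustration with made-up checkpoint keys (not taken from a real model):

    # A checkpoint saved from a model with a task head stores base-model weights under a prefix.
    state_dict = {
        "bert.encoder.layer.0.attention.self.query.weight": None,
        "cls.predictions.bias": None,
    }
    base_model_prefix = "bert"

    has_prefix_module = any(s.startswith(base_model_prefix) for s in state_dict.keys())
    print(has_prefix_module)  # True: loading this into a bare base model needs start_prefix = "bert."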
@@ -627,7 +717,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             )
         model.tie_weights()  # make sure token embedding weights are still tied if needed

-        # Set model in evaluation mode to desactivate DropOut modules by default
+        # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()

         if output_loading_info:
@@ -944,7 +1034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # get encoder and store encoder outputs
             encoder = self.get_encoder()

-            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
+            encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)

         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
@@ -1446,12 +1536,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             scores[:, all_but_token_ids_mask] = -float("inf")

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
         return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)


-def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
-    # Copied from fairseq for no_repeat_ngram in beam_search"""
+def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
+    """Copied from fairseq for no_repeat_ngram in beam_search"""
     if cur_len + 1 < no_repeat_ngram_size:
         # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
         return [[] for _ in range(num_hypos)]
@@ -1883,9 +1973,7 @@ class SequenceSummary(nn.Module):
             self.summary = nn.Linear(config.hidden_size, num_classes)

         activation_string = getattr(config, "summary_activation", None)
-        self.activation = (
-            get_activation(activation_string) if activation_string else Identity()
-        )  # type: typing.Callable
+        self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())

         self.first_dropout = Identity()
         if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
@@ -479,23 +479,7 @@ class XLMModel(XLMPreTrainedModel):
         # langs = langs.transpose(0, 1)

         # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.n_layers
+        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

         # do not recompute cached elements
         if cache is not None and input_ids is not None:
@@ -349,10 +349,12 @@ class XxxModel(XxxPreTrainedModel):
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         # We create a 3D attention mask from a 2D tensor mask.
+        # (this can be done with self.invert_attention_mask)
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.

         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
@@ -368,19 +370,7 @@ class XxxModel(XxxPreTrainedModel):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = (
-                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-                )  # We can specify head_mask for each layer
-            head_mask = head_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_hidden_layers
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         ##################################
         # Replace this with your model code