diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py index eadb5fa195e..a231f024392 100644 --- a/src/transformers/modeling_albert.py +++ b/src/transformers/modeling_albert.py @@ -552,19 +552,7 @@ class AlbertModel(AlbertPreTrainedModel): extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings( input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index d0231d5bd18..f9c02ada97b 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -703,36 +703,9 @@ class BertModel(BertPreTrainedModel): # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - causal_mask = causal_mask.to( - attention_mask.dtype - ) # causal and attention masks must have same type with pytorch version < 1.3 - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( - input_shape, attention_mask.shape - ) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
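The removed comments above describe the additive-mask trick that the new shared helpers keep using: a 1/0 keep/pad mask is broadcast to [batch_size, 1, 1, seq_length] and mapped to 0/-10000 so it can simply be added to the raw attention scores. A small standalone illustration, not part of the diff; the tensor values are made up for this sketch:

    import torch

    # A padding mask for one sequence of length 4 whose last token is padding.
    attention_mask = torch.tensor([[1, 1, 1, 0]])

    # Broadcast to [batch_size, 1, 1, seq_length] and convert 1/0 into 0/-10000,
    # so adding it to the attention scores before softmax effectively removes
    # the padded positions.
    extended_attention_mask = attention_mask[:, None, None, :].to(torch.float)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    print(extended_attention_mask.shape)  # torch.Size([1, 1, 1, 4]); 0. for kept positions, -10000. for the padded one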
- extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, self.device + ) # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] @@ -741,22 +714,7 @@ class BertModel(BertPreTrainedModel): encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - - if encoder_attention_mask.dim() == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - elif encoder_attention_mask.dim() == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format( - encoder_hidden_shape, encoder_attention_mask.shape - ) - ) - - encoder_extended_attention_mask = encoder_extended_attention_mask.to( - dtype=next(self.parameters()).dtype - ) # fp16 compatibility - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -765,19 +723,7 @@ class BertModel(BertPreTrainedModel): # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index 377655cefae..6c6a264cb82 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -392,26 +392,11 @@ class CTRLModel(CTRLPreTrainedModel): # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. 
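With the duplicated blocks gone, each forward() reduces to calls on the mixin helpers defined later in this diff (get_extended_attention_mask, invert_attention_mask, get_head_mask). A minimal usage sketch against the refactored BertModel, assuming only that the helpers behave as added in modeling_utils.py below; the config sizes are arbitrary:

    import torch
    from transformers import BertConfig, BertModel

    # Tiny randomly initialised model so the example runs quickly.
    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
    model = BertModel(config)

    attention_mask = torch.tensor([[1, 1, 1, 0]])
    extended = model.get_extended_attention_mask(attention_mask, attention_mask.shape, model.device)
    print(extended.shape)   # torch.Size([1, 1, 1, 4]); 0 where attended, -10000 where masked

    # [num_heads] -> one 5-D mask per layer; passing None would give [None] * num_hidden_layers.
    head_mask = model.get_head_mask(torch.tensor([1.0, 0.0]), config.num_hidden_layers)
    print(head_mask.shape)  # torch.Size([2, 1, 2, 1, 1])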
- attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.n_layer + head_mask = self.get_head_mask(head_mask, self.config.n_layer) if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) diff --git a/src/transformers/modeling_distilbert.py b/src/transformers/modeling_distilbert.py index ce715034ac8..3dff00fa03c 100644 --- a/src/transformers/modeling_distilbert.py +++ b/src/transformers/modeling_distilbert.py @@ -460,23 +460,7 @@ class DistilBertModel(DistilBertPreTrainedModel): attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim) diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py index 2d46716c357..ffe66b073a2 100644 --- a/src/transformers/modeling_electra.py +++ b/src/transformers/modeling_electra.py @@ -164,65 +164,6 @@ class ElectraPreTrainedModel(BertPreTrainedModel): load_tf_weights = load_tf_weights_in_electra base_model_prefix = "electra" - def get_extended_attention_mask(self, attention_mask, input_shape, device): - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
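The attention_mask.to(dtype=self.dtype) change in modeling_ctrl.py above relies on the new dtype property that ModuleUtilsMixin gains further down in modeling_utils.py, a one-line wrapper around the old next(self.parameters()).dtype idiom. A self-contained sketch of the idea on an ordinary nn.Module; TinyModule is invented for this illustration:

    import torch
    from torch import nn

    class TinyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        @property
        def dtype(self) -> torch.dtype:
            # Same trick as the mixin: report the dtype of the first parameter.
            return next(self.parameters()).dtype

    module = TinyModule().half()
    mask = torch.ones(1, 4).to(dtype=module.dtype)  # fp16 compatibility
    print(mask.dtype)  # torch.float16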
- if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - causal_mask = causal_mask.to( - attention_mask.dtype - ) # causal and attention masks must have same type with pytorch version < 1.3 - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( - input_shape, attention_mask.shape - ) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - return extended_attention_mask - - def get_head_mask(self, head_mask): - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - num_hidden_layers = self.config.num_hidden_layers - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * num_hidden_layers - - return head_mask - ELECTRA_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. 
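The Electra-local helpers removed above contain the same logic that now lives once in ModuleUtilsMixin. For context, a standalone sketch of the causal-mask construction they perform for decoders; shapes and values are illustrative only:

    import torch

    batch_size, seq_length = 1, 4
    seq_ids = torch.arange(seq_length)

    # Position i may attend to positions j <= i: a lower-triangular matrix per batch item.
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    print(causal_mask[0].int())
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]], dtype=torch.int32)

    # Combined with the padding mask this gives the decoder's extended attention mask.
    attention_mask = torch.tensor([[1, 1, 1, 0]])
    extended = causal_mask[:, None, :, :].to(torch.float) * attention_mask[:, None, None, :]
    print(extended.shape)  # torch.Size([1, 1, 4, 4])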
@@ -376,7 +317,7 @@ class ElectraModel(ElectraPreTrainedModel): token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) - head_mask = self.get_head_mask(head_mask) + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds diff --git a/src/transformers/modeling_flaubert.py b/src/transformers/modeling_flaubert.py index 7236e44a163..80da729521a 100644 --- a/src/transformers/modeling_flaubert.py +++ b/src/transformers/modeling_flaubert.py @@ -201,23 +201,7 @@ class FlaubertModel(XLMModel): # langs = langs.transpose(0, 1) # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.n_layers + head_mask = self.get_head_mask(head_mask, self.config.n_layers) # do not recompute cached elements if cache is not None and input_ids is not None: diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 75717d75237..120139964e3 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -471,19 +471,7 @@ class GPT2Model(GPT2PreTrainedModel): # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # head_mask has shape n_layer x batch x n_heads x N x N - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.n_layer + head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) diff --git a/src/transformers/modeling_mmbt.py b/src/transformers/modeling_mmbt.py index a3aae389658..0eddaa72f0e 100644 --- a/src/transformers/modeling_mmbt.py +++ b/src/transformers/modeling_mmbt.py @@ -23,6 +23,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss, MSELoss from .file_utils import add_start_docstrings +from .modeling_utils import ModuleUtilsMixin logger = logging.getLogger(__name__) @@ -148,7 +149,7 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs: MMBT_START_DOCSTRING, MMBT_INPUTS_DOCSTRING, ) -class MMBTModel(nn.Module): +class MMBTModel(ModuleUtilsMixin): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, 
sequence_length, hidden_size)`` @@ -237,7 +238,6 @@ class MMBTModel(nn.Module): attention_mask = torch.cat( [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1 ) - if encoder_attention_mask is None: encoder_attention_mask = torch.ones(input_shape, device=device) else: @@ -245,61 +245,9 @@ class MMBTModel(nn.Module): [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1 ) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if attention_mask.dim() == 2: - if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] - if encoder_attention_mask.dim() == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if encoder_attention_mask.dim() == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - encoder_extended_attention_mask = encoder_extended_attention_mask.to( - dtype=next(self.parameters()).dtype - ) # fp16 compatibility - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device) + encoder_extended_attention_mask = 
self.invert_attention_mask(encoder_attention_mask) + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) encoder_outputs = self.transformer.encoder( embedding_output, diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py index c81620fd1b5..3e65495f834 100644 --- a/src/transformers/modeling_openai.py +++ b/src/transformers/modeling_openai.py @@ -425,22 +425,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): attention_mask = (1.0 - attention_mask) * -10000.0 # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.n_layer + head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: inputs_embeds = self.tokens_embed(input_ids) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index e633bc8b176..e78db03905a 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -184,7 +184,7 @@ class T5LayerFF(nn.Module): class T5Attention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config: T5Config, has_relative_attention_bias=False): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -693,73 +693,15 @@ class T5Stack(T5PreTrainedModel): # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.config.is_decoder: - seq_ids = torch.arange(mask_seq_length, device=inputs_embeds.device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None] - causal_mask = causal_mask.to(attention_mask) - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - if past_key_value_states[0] is not None: - extended_attention_mask = extended_attention_mask[:, :, -1:, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -1e9 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2)) - - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device) if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastabe to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if encoder_attention_mask.dim() == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if encoder_attention_mask.dim() == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2)) - - encoder_extended_attention_mask = encoder_extended_attention_mask.to( - dtype=next(self.parameters()).dtype - ) # fp16 compatibility - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x mask_seq_length x mask_seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_layers - + head_mask = self.get_head_mask(head_mask, self.config.num_layers) present_key_value_states = () all_hidden_states = () all_attentions = () diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0f7143c22fe..f0df0c1ee51 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -17,10 +17,10 @@ import logging import os -import typing +from typing import Callable, Tuple import torch -from torch import nn +from torch import Tensor, device, dtype, nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F @@ -109,9 +109,102 @@ class ModuleUtilsMixin: module.mem_rss_pre_forward = 0 @property - def device(self): + def device(self) -> device: return next(self.parameters()).device + @property + def dtype(self) -> dtype: + return next(self.parameters()).dtype + + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: + """type: torch.Tensor -> torch.Tensor""" + if encoder_attention_mask.dim() == 3: + 
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 + return encoder_extended_attention_mask + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device): + """Makes broadcastable attention mask and causal mask so that future and masked tokens are ignored. + + Arguments: + attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to + input_shape: tuple, shape of input_ids + device: torch.Device, usually self.device + + Returns: + torch.Tensor with dtype of attention_mask.dtype + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def get_head_mask(self, head_mask, num_hidden_layers): + """ + # Prepare head mask if needed + # 1.0 in head_mask indicates we keep the head + attention_probs has shape bsz x n_heads x N x N + Arguments: + head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads] + num_hidden_layers: int + Returns: + Tensor of shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + or list with [None] for each layer + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = head_mask.to(dtype=self.dtype) # switch to float if needed + fp16 compatibility + return head_mask + class PreTrainedModel(nn.Module, ModuleUtilsMixin): r""" Base class for all models. @@ -340,7 +433,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # If we save using the predefined names, we can load using `from_pretrained` output_model_file = os.path.join(save_directory, WEIGHTS_NAME) - if hasattr(self.config, "xla_device") and self.config.xla_device: + if getattr(self.config, "xla_device", False): import torch_xla.core.xla_model as xm if xm.is_master_ordinal(): @@ -588,13 +681,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # Make sure we are able to load base models as well as derived models (with heads) start_prefix = "" model_to_load = model - if not hasattr(model, cls.base_model_prefix) and any( - s.startswith(cls.base_model_prefix) for s in state_dict.keys() - ): + has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) + if not hasattr(model, cls.base_model_prefix) and has_prefix_module: start_prefix = cls.base_model_prefix + "."
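A runnable sketch of what the new get_head_mask / _convert_head_mask_to_5d pair produces for the two supported input shapes. The standalone function below mirrors the added mixin code for illustration only; the sizes are arbitrary:

    import torch

    num_hidden_layers, num_heads = 2, 4

    def convert_head_mask_to_5d(head_mask, num_hidden_layers, dtype=torch.float32):
        # Mirrors ModuleUtilsMixin._convert_head_mask_to_5d from the hunk above (sketch).
        if head_mask.dim() == 1:
            # [num_heads] -> the same mask reused for every layer.
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            # [num_hidden_layers, num_heads] -> one mask per layer.
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5
        return head_mask.to(dtype=dtype)

    print(convert_head_mask_to_5d(torch.ones(num_heads), num_hidden_layers).shape)
    # torch.Size([2, 1, 4, 1, 1])
    print(convert_head_mask_to_5d(torch.ones(num_hidden_layers, num_heads), num_hidden_layers).shape)
    # torch.Size([2, 1, 4, 1, 1])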
- if hasattr(model, cls.base_model_prefix) and not any( - s.startswith(cls.base_model_prefix) for s in state_dict.keys() - ): + if hasattr(model, cls.base_model_prefix) and not has_prefix_module: model_to_load = getattr(model, cls.base_model_prefix) load(model_to_load, prefix=start_prefix) @@ -627,7 +717,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) model.tie_weights() # make sure token embedding weights are still tied if needed - # Set model in evaluation mode to desactivate DropOut modules by default + # Set model in evaluation mode to deactivate DropOut modules by default model.eval() if output_loading_info: @@ -944,7 +1034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # get encoder and store encoder outputs encoder = self.get_encoder() - encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask) # Expand input ids if num_beams > 1 or num_return_sequences > 1 if num_return_sequences > 1 or num_beams > 1: @@ -1446,12 +1536,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): scores[:, all_but_token_ids_mask] = -float("inf") @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]: return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) -def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): - # Copied from fairseq for no_repeat_ngram in beam_search""" +def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None: + """Copied from fairseq for no_repeat_ngram in beam_search""" if cur_len + 1 < no_repeat_ngram_size: # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet return [[] for _ in range(num_hypos)] @@ -1883,9 +1973,7 @@ class SequenceSummary(nn.Module): self.summary = nn.Linear(config.hidden_size, num_classes) activation_string = getattr(config, "summary_activation", None) - self.activation = ( - get_activation(activation_string) if activation_string else Identity() - ) # type: typing.Callable + self.activation: Callable = (get_activation(activation_string) if activation_string else Identity()) self.first_dropout = Identity() if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: diff --git a/src/transformers/modeling_xlm.py b/src/transformers/modeling_xlm.py index 82659156973..cb54a054e3d 100644 --- a/src/transformers/modeling_xlm.py +++ b/src/transformers/modeling_xlm.py @@ -479,23 +479,7 @@ class XLMModel(XLMPreTrainedModel): # langs = langs.transpose(0, 1) # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.n_layers + head_mask = self.get_head_mask(head_mask, self.config.n_layers) # do not recompute cached elements if 
cache is not None and input_ids is not None: diff --git a/templates/adding_a_new_model/modeling_xxx.py b/templates/adding_a_new_model/modeling_xxx.py index a92f3cbe550..60529ce870b 100644 --- a/templates/adding_a_new_model/modeling_xxx.py +++ b/templates/adding_a_new_model/modeling_xxx.py @@ -349,10 +349,12 @@ class XxxModel(XxxPreTrainedModel): token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # We create a 3D attention mask from a 2D tensor mask. + # (this can be done with self.invert_attention_mask) # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for @@ -368,19 +370,7 @@ class XxxModel(XxxPreTrainedModel): # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) ################################## # Replace this with your model code
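As the template comment added above points out, the hand-built broadcastable mask can also be obtained from the mixin. A minimal sketch of what invert_attention_mask does to a 2-D padding mask; the standalone function and the float32 dtype are assumptions for this illustration (the real helper uses self.dtype and fills with -1e9):

    import torch

    def invert_attention_mask(encoder_attention_mask, dtype=torch.float32):
        # Sketch of ModuleUtilsMixin.invert_attention_mask: 1/0 keep/mask -> additive 0/-1e9.
        if encoder_attention_mask.dim() == 3:
            mask = encoder_attention_mask[:, None, :, :]
        if encoder_attention_mask.dim() == 2:
            mask = encoder_attention_mask[:, None, None, :]
        mask = mask.to(dtype=dtype)
        return (1.0 - mask) * -1e9

    inverted = invert_attention_mask(torch.tensor([[1, 1, 0]]))
    print(inverted.shape)  # torch.Size([1, 1, 1, 3]); 0 for kept tokens, -1e9 for the masked one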