# coding=utf-8 # Copyright 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and Microsoft Corporation. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch BERT model with Patience-based Early Exit. """ import logging import torch from torch import nn from torch.nn import CrossEntropyLoss, MSELoss from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.models.bert.modeling_bert import ( BERT_INPUTS_DOCSTRING, BERT_START_DOCSTRING, BertEncoder, BertModel, BertPreTrainedModel, ) logger = logging.getLogger(__name__) class BertEncoderWithPabee(BertEncoder): def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None): layer_outputs = self.layer[current_layer](hidden_states, attention_mask, head_mask[current_layer]) hidden_states = layer_outputs[0] return hidden_states @add_start_docstrings( "The bare Bert Model transformer with PABEE outputting raw hidden-states without any specific head on top.", BERT_START_DOCSTRING, ) class BertModelWithPabee(BertModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration set to :obj:`True`; an :obj:`encoder_hidden_states` is expected as an input to the forward pass. .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762 """ def __init__(self, config): super().__init__(config) self.encoder = BertEncoderWithPabee(config) self.init_weights() self.patience = 0 self.inference_instances_num = 0 self.inference_layers_num = 0 self.regression_threshold = 0 def set_regression_threshold(self, threshold): self.regression_threshold = threshold def set_patience(self, patience): self.patience = patience def reset_stats(self): self.inference_instances_num = 0 self.inference_layers_num = 0 def log_stats(self): avg_inf_layers = self.inference_layers_num / self.inference_instances_num message = ( f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up =" f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***" ) print(message) @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, output_dropout=None, output_layers=None, regression=False, ): r""" Return: :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pre-training. This output is usually *not* a good summary of the semantic content of the input, you're often better with averaging or pooling the sequence of hidden-states for the whole input sequence. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder and encoder_hidden_states is not None: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) encoder_outputs = embedding_output if self.training: res = [] for i in range(self.config.num_hidden_layers): encoder_outputs = self.encoder.adaptive_forward( encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask ) pooled_output = self.pooler(encoder_outputs) logits = output_layers[i](output_dropout(pooled_output)) res.append(logits) elif self.patience == 0: # Use all layers for inference encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, ) pooled_output = self.pooler(encoder_outputs[0]) res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)] else: patient_counter = 0 patient_result = None calculated_layer_num = 0 for i in range(self.config.num_hidden_layers): calculated_layer_num += 1 encoder_outputs = self.encoder.adaptive_forward( encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask ) pooled_output = self.pooler(encoder_outputs) logits = output_layers[i](pooled_output) if regression: labels = logits.detach() if patient_result is not None: patient_labels = patient_result.detach() if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold: patient_counter += 1 else: patient_counter = 0 else: labels = logits.detach().argmax(dim=1) if patient_result is not None: patient_labels = patient_result.detach().argmax(dim=1) if (patient_result is not None) and torch.all(labels.eq(patient_labels)): patient_counter += 1 else: patient_counter = 0 patient_result = logits if patient_counter == self.patience: break res = [patient_result] self.inference_layers_num += calculated_layer_num self.inference_instances_num += 1 return res @add_start_docstrings( """Bert Model transformer with PABEE and a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, BERT_START_DOCSTRING, ) class BertForSequenceClassificationWithPabee(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.bert = BertModelWithPabee(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifiers = nn.ModuleList( [nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)] ) self.init_weights() @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). Returns: :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): Classification (or regression if config.num_labels==1) loss. logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Examples:: from transformers import BertTokenizer, BertForSequenceClassification from pabee import BertForSequenceClassificationWithPabee from torch import nn import torch tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-uncased') model = BertForSequenceClassificationWithPabee.from_pretrained('google-bert/bert-base-uncased') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 outputs = model(input_ids, labels=labels) loss, logits = outputs[:2] """ logits = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_dropout=self.dropout, output_layers=self.classifiers, regression=self.num_labels == 1, ) outputs = (logits[-1],) if labels is not None: total_loss = None total_weights = 0 for ix, logits_item in enumerate(logits): if self.num_labels == 1: # We are doing regression loss_fct = MSELoss() loss = loss_fct(logits_item.view(-1), labels.view(-1)) else: loss_fct = CrossEntropyLoss() loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1)) if total_loss is None: total_loss = loss else: total_loss += loss * (ix + 1) total_weights += ix + 1 outputs = (total_loss / total_weights,) + outputs return outputs