import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.models.bert.modeling_bert import (
    BERT_INPUTS_DOCSTRING,
    BERT_START_DOCSTRING,
    BertEmbeddings,
    BertLayer,
    BertPooler,
    BertPreTrainedModel,
)


def entropy(x):
    """Calculate the entropy of a pre-softmax logit Tensor.

    The entropy of ``softmax(x)`` is computed along ``dim=1`` as
    ``log(sum_i exp(x_i)) - sum_i x_i * exp(x_i) / sum_i exp(x_i)``.

    Args:
        x: tensor of logits with at least two dimensions; the reduction runs over ``dim=1``.

    Returns:
        A tensor with ``dim=1`` reduced away, holding one entropy value per row.
    """
    # Entropy is invariant under adding a constant to every logit of a row, so the
    # row maximum is subtracted first. Without the shift, exp() overflows to inf for
    # large logits and the result becomes nan.
    shifted = x - x.max(dim=1, keepdim=True)[0]
    exp_shifted = torch.exp(shifted)
    A = torch.sum(exp_shifted, dim=1)  # sum of exp(x_i - max)
    B = torch.sum(shifted * exp_shifted, dim=1)  # sum of (x_i - max) * exp(x_i - max)
    return torch.log(A) - B / A


class DeeBertEncoder(nn.Module):
    """Stack of ``BertLayer`` modules, each followed by a ``BertHighway`` early-exit head."""

    def __init__(self, config):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])

        # One entropy threshold per layer. Entropy is non-negative in exact arithmetic,
        # so the default of -1 effectively disables early exiting.
        self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]

    def set_early_exit_entropy(self, x):
        """Set the early-exit entropy threshold(s).

        Args:
            x: either a single number, applied to every layer, or a per-layer
                collection of thresholds that replaces the current one as-is.
        """
        if isinstance(x, (int, float)):
            for i in range(len(self.early_exit_entropy)):
                self.early_exit_entropy[i] = x
        else:
            self.early_exit_entropy = x

    def init_highway_pooler(self, pooler):
        """Copy the parameters of ``pooler`` into the pooler of every highway head."""
        loaded_model = pooler.state_dict()
        for highway in self.highway:
            for name, param in highway.pooler.state_dict().items():
                param.copy_(loaded_model[name])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """Run the layers in order, evaluating the highway head after each one.

        Returns:
            A tuple ``(last hidden state, (all hidden states), (all attentions),
            all highway exits)``; the parenthesised entries are only present when the
            corresponding ``output_*`` flag is set. In training mode every highway exit
            is the tuple returned by the highway head; in eval mode the entropy of its
            logits is appended as a third entry.

        Raises:
            HighwayException: in eval mode, as soon as the entropy of a highway head
                falls below that layer's threshold. The exception carries the outputs
                gathered so far and the 1-based index of the exit layer.
        """
        all_hidden_states = ()
        all_attentions = ()
        all_highway_exits = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # head_mask defaults to None, so it must not be indexed unconditionally.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(
                hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask
            )
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            current_outputs = (hidden_states,)
            if self.output_hidden_states:
                current_outputs = current_outputs + (all_hidden_states,)
            if self.output_attentions:
                current_outputs = current_outputs + (all_attentions,)

            highway_exit = self.highway[i](current_outputs)
            # logits, pooled_output

            if not self.training:
                highway_logits = highway_exit[0]
                highway_entropy = entropy(highway_logits)
                highway_exit = highway_exit + (highway_entropy,)  # logits, pooled_output, entropy
                all_highway_exits = all_highway_exits + (highway_exit,)

                # NOTE(review): the comparison is used as a Python bool, which presumably
                # means inference is expected to run with batch size 1 - confirm.
                if highway_entropy < self.early_exit_entropy[i]:
                    new_output = (highway_logits,) + current_outputs[1:] + (all_highway_exits,)
                    raise HighwayException(new_output, i + 1)
            else:
                all_highway_exits = all_highway_exits + (highway_exit,)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)

        outputs = outputs + (all_highway_exits,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions), all highway exits


@add_start_docstrings(
    "The Bert Model transformer with early exiting (DeeBERT). ",
    BERT_START_DOCSTRING,
)
class DeeBertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = DeeBertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def init_highway_pooler(self):
        """Initialise every highway pooler of the encoder from this model's pooler."""
        self.encoder.init_highway_pooler(self.pooler)

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        r"""
        Return:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
            pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
                Last layer hidden-state of the first token of the sequence (classification token)
                further processed by a Linear layer and a Tanh activation function. The Linear
                layer weights are trained from the next sentence prediction (classification)
                objective during pre-training.

                This output is usually *not* a good summary
                of the semantic content of the input, you're often better with averaging or pooling
                the sequence of hidden-states for the whole input sequence.
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
            highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
                Tuple of each early exit's results (total length: number of layers).
                Each entry is again a tuple: the first item is the highway logits and the second item is the
                highway's pooled output; in eval mode the entropy of the logits is appended as a third item.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given, or if
                ``encoder_attention_mask`` is neither 2D nor 3D.
            HighwayException: propagated from the encoder when an early exit is taken in eval mode.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        mask_rank = encoder_attention_mask.dim()
        if mask_rank == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif mask_rank == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            # Any other rank used to leave the variable unbound and fail below with an
            # UnboundLocalError; fail early with an explicit message instead.
            raise ValueError(
                "encoder_attention_mask must be a 2D or 3D tensor, but got a tensor with {} dimension(s)".format(
                    mask_rank
                )
            )

        encoder_extended_attention_mask = encoder_extended_attention_mask.to(
            dtype=next(self.parameters()).dtype
        )  # fp16 compatibility
        # Kept positions (1.0) become 0.0 and masked positions (0.0) become -10000.0.
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output) + encoder_outputs[1:]
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions), highway exits


class HighwayException(Exception):
    """Raised by the encoder to signal an early exit.

    Attributes are set directly on the instance:
    ``message`` holds the outputs gathered up to the exit, ``exit_layer`` the exit layer index.
    """

    def __init__(self, message, exit_layer):
        self.message = message
        self.exit_layer = exit_layer  # start from 1!
class BertHighway(nn.Module):
    """A module to provide a shortcut
    from (the output of one non-final BertLayer in BertEncoder) to (cross-entropy computation in BertForSequenceClassification)
    """

    def __init__(self, config):
        super().__init__()
        self.pooler = BertPooler(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, encoder_outputs):
        """Classify from the hidden state of an intermediate layer.

        Args:
            encoder_outputs: sequence whose first item is the hidden state to pool;
                any further items are ignored.

        Returns:
            A tuple ``(logits, pooled_output)``, where ``pooled_output`` is the pooler
            output after dropout has been applied.
        """
        # Pooler
        pooler_input = encoder_outputs[0]
        pooler_output = self.pooler(pooler_input)

        # Dropout and classification
        pooled_output = self.dropout(pooler_output)
        logits = self.classifier(pooled_output)

        return logits, pooled_output


@add_start_docstrings(
    """Bert Model (with early exiting - DeeBERT) with a classifier on top,
    also takes care of multi-layer training. """,
    BERT_START_DOCSTRING,
)
class DeeBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_layers = config.num_hidden_layers

        self.bert = DeeBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def _compute_loss(self, logits, labels):
        """Return the MSE loss when ``num_labels == 1`` (regression), else the cross-entropy loss."""
        if self.num_labels == 1:
            # We are doing regression
            loss_fct = MSELoss()
            return loss_fct(logits.view(-1), labels.view(-1))
        loss_fct = CrossEntropyLoss()
        return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_layer=-1,
        train_highway=False,
    ):
        r"""
            labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
                Labels for computing the sequence classification/regression loss.
                Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
                If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
                If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
            output_layer (:obj:`int`, `optional`, defaults to -1):
                In eval mode, when non-negative, the returned logits are replaced by the logits of the
                highway exit with this index. The index must refer to an exit that was actually evaluated.
            train_highway (:obj:`bool`, `optional`, defaults to :obj:`False`):
                When labels are given, return the sum of the losses of all highway exits except the last one
                instead of the loss of the final classifier.

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
                Classification (or regression if config.num_labels==1) loss.
            logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
            highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
                Tuple of each early exit's results (total length: number of layers evaluated).
                Each entry is again a tuple whose first item is the highway logits.

            In eval mode two more items are appended: the pair ``(entropy of the returned logits,
            list of highway entropies)`` and the 1-based index of the layer at which the model exited.
        """

        exit_layer = self.num_layers
        try:
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
            # sequence_output, pooled_output, (hidden_states), (attentions), highway exits

            pooled_output = outputs[1]

            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        except HighwayException as e:
            # Early exit: the exception carries the outputs gathered up to the exit layer.
            outputs = e.message
            exit_layer = e.exit_layer
            logits = outputs[0]

        # The highway exits are always the last item, on both the regular and the early-exit path.
        highway_exits = outputs[-1]

        if not self.training:
            original_entropy = entropy(logits)
            # Collected regardless of whether labels were given, so that `output_layer`
            # also works for label-free inference.
            highway_logits_all = [highway_exit[0] for highway_exit in highway_exits]
            highway_entropy = [highway_exit[2] for highway_exit in highway_exits]

        if labels is not None:
            if train_highway:
                # exclude the final highway, of course
                highway_losses = [self._compute_loss(highway_exit[0], labels) for highway_exit in highway_exits[:-1]]
                outputs = (sum(highway_losses),) + outputs
            else:
                outputs = (self._compute_loss(logits, labels),) + outputs

        if not self.training:
            outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
            if output_layer >= 0:
                # The logits sit after the loss when one was prepended, otherwise they come first.
                logits_index = 1 if labels is not None else 0
                outputs = (
                    outputs[:logits_index] + (highway_logits_all[output_layer],) + outputs[logits_index + 1 :]
                )  # use the highway of the requested layer

        return outputs  # (loss), logits, (hidden_states), (attentions), (highway_exits)