# coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch BERT model.""" from __future__ import (absolute_import, division, print_function, unicode_literals) import copy import json import logging import os from io import open import six import torch from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F from .file_utils import cached_path logger = logging.getLogger(__name__) CONFIG_NAME = "config.json" WEIGHTS_NAME = "pytorch_model.bin" TF_WEIGHTS_NAME = 'model.ckpt' if not six.PY2: def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = ''.join(docstr) + fn.__doc__ return fn return docstring_decorator else: # Not possible to update class docstrings on python2 def add_start_docstrings(*docstr): def docstring_decorator(fn): return fn return docstring_decorator class PretrainedConfig(object): """ An abstract class to handle dowloading a model pretrained config. """ pretrained_config_archive_map = {} def __init__(self, **kwargs): self.finetuning_task = kwargs.pop('finetuning_task', None) self.num_labels = kwargs.pop('num_labels', 2) self.output_attentions = kwargs.pop('output_attentions', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.torchscript = kwargs.pop('torchscript', False) def save_pretrained(self, save_directory): """ Save a configuration file to a directory, so that it can be re-loaded using the `from_pretrained(save_directory)` class method. """ assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" # If we save using the predefined names, we can load using `from_pretrained` output_config_file = os.path.join(save_directory, CONFIG_NAME) self.to_json_file(output_config_file) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs): """ Instantiate a PretrainedConfig from a pre-trained model configuration. Params: pretrained_model_name_or_path: either: - a str with the name of a pre-trained model to load selected in the list of: . `xlnet-large-cased` - a path or url to a directory containing a configuration file `config.json` for the model, - a path or url to a configuration file for the model. cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached. """ cache_dir = kwargs.pop('cache_dir', None) if pretrained_model_name_or_path in cls.pretrained_config_archive_map: config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] elif os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) else: config_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: resolved_config_file = cached_path(config_file, cache_dir=cache_dir) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_config_archive_map: logger.error( "Couldn't reach server at '{}' to download pretrained model configuration file.".format( config_file)) else: logger.error( "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( pretrained_model_name_or_path, ', '.join(cls.pretrained_config_archive_map.keys()), config_file)) return None if resolved_config_file == config_file: logger.info("loading configuration file {}".format(config_file)) else: logger.info("loading configuration file {} from cache at {}".format( config_file, resolved_config_file)) # Load config config = cls.from_json_file(resolved_config_file) # Update config with kwargs if needed to_remove = [] for key, value in kwargs.items(): if hasattr(config, key): setattr(config, key, value) to_remove.append(key) for key in to_remove: kwargs.pop(key, None) logger.info("Model config %s", config) return config @classmethod def from_dict(cls, json_object): """Constructs a `Config` from a Python dictionary of parameters.""" config = cls(vocab_size_or_config_json_file=-1) for key, value in json_object.items(): config.__dict__[key] = value return config @classmethod def from_json_file(cls, json_file): """Constructs a `BertConfig` from a json file of parameters.""" with open(json_file, "r", encoding='utf-8') as reader: text = reader.read() return cls.from_dict(json.loads(text)) def __eq__(self, other): return self.__dict__ == other.__dict__ def __repr__(self): return str(self.to_json_string()) def to_dict(self): """Serializes this instance to a Python dictionary.""" output = copy.deepcopy(self.__dict__) return output def to_json_string(self): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path): """ Save this instance to a json file.""" with open(json_file_path, "w", encoding='utf-8') as writer: writer.write(self.to_json_string()) class PreTrainedModel(nn.Module): """ An abstract class to handle storing model config and a simple interface for dowloading and loading pretrained models. """ config_class = PretrainedConfig pretrained_model_archive_map = {} load_tf_weights = lambda model, config, path: None base_model_prefix = "" input_embeddings = None def __init__(self, config, *inputs, **kwargs): super(PreTrainedModel, self).__init__() if not isinstance(config, PretrainedConfig): raise ValueError( "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " "To create a model from a pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ )) # Save config in model self.config = config def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): """ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly initialized vectors at the end Reducing the size will remove vectors from the end Args: new_num_tokens: (Optional) New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end Reducing the size will remove vectors from the end If not provided or None: return the provided token Embedding Module. Return: Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None """ if new_num_tokens is None: return old_embeddings old_num_tokens, old_embedding_dim = old_embeddings.weight.size() if old_num_tokens == new_num_tokens: return old_embeddings # Build new embeddings new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) new_embeddings.to(old_embeddings.weight.device) # initialize all new embeddings (in particular added tokens) self.init_weights(new_embeddings) # Copy word embeddings from the previous weights num_tokens_to_copy = min(old_num_tokens, new_num_tokens) new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] return new_embeddings def _tie_or_clone_weights(self, first_module, second_module): """ Tie or clone module weights depending of weither we are using TorchScript or not """ if self.config.torchscript: first_module.weight = nn.Parameter(second_module.weight.clone()) else: first_module.weight = second_module.weight def resize_token_embeddings(self, new_num_tokens=None): """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. Args: new_num_tokens: (Optional) New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end Reducing the size will remove vectors from the end If not provided or None: does nothing. Return: Pointer to the input tokens Embedding Module of the model """ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed model_embeds = base_model._resize_token_embeddings(new_num_tokens) if new_num_tokens is None: return model_embeds # Update base model and current model config self.config.vocab_size = new_num_tokens base_model.vocab_size = new_num_tokens # Tie weights again if needed if hasattr(self, 'tie_weights'): self.tie_weights() return model_embeds def prune_heads(self, heads_to_prune): """ Prunes heads of the base model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} """ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed base_model._prune_heads(heads_to_prune) def save_pretrained(self, save_directory): """ Save a model with its configuration file to a directory, so that it can be re-loaded using the `from_pretrained(save_directory)` class method. """ assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" # Only save the model it-self if we are using distributed training model_to_save = self.module if hasattr(self, 'module') else self # Save configuration file model_to_save.config.save_pretrained(save_directory) # If we save using the predefined names, we can load using `from_pretrained` output_model_file = os.path.join(save_directory, WEIGHTS_NAME) torch.save(model_to_save.state_dict(), output_model_file) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): """ Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: pretrained_model_name_or_path: either: - a str with the name of a pre-trained model to load, or - a path or url to a pretrained model archive containing: . `config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance - a path or url to a tensorflow pretrained model checkpoint containing: . `config.json` a configuration file for the model . `model.chkpt` a TensorFlow checkpoint config: an optional configuration for the model from_tf: should we load the weights from a locally saved TensorFlow checkpoint cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific XLNet class (ex: num_labels for XLNetForSequenceClassification) """ config = kwargs.pop('config', None) state_dict = kwargs.pop('state_dict', None) cache_dir = kwargs.pop('cache_dir', None) from_tf = kwargs.pop('from_tf', False) output_loading_info = kwargs.pop('output_loading_info', False) # Load config if config is None: config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) # Load model if pretrained_model_name_or_path in cls.pretrained_model_archive_map: archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] elif os.path.isdir(pretrained_model_name_or_path): if from_tf: # Directly load from a TensorFlow checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") else: archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) else: if from_tf: # Directly load from a TensorFlow checkpoint archive_file = pretrained_model_name_or_path + ".index" else: archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: logger.error( "Couldn't reach server at '{}' to download pretrained weights.".format( archive_file)) else: logger.error( "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( pretrained_model_name_or_path, ', '.join(cls.pretrained_model_archive_map.keys()), archive_file)) return None if resolved_archive_file == archive_file: logger.info("loading weights file {}".format(archive_file)) else: logger.info("loading weights file {} from cache at {}".format( archive_file, resolved_archive_file)) # Instantiate model. model = cls(config) if state_dict is None and not from_tf: state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if 'gamma' in key: new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) # Load from a PyTorch state_dict missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') # Make sure we are able to load base models as well as derived models (with heads) start_prefix = '' model_to_load = model if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): start_prefix = cls.base_model_prefix + '.' if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): model_to_load = getattr(model, cls.base_model_prefix) load(model_to_load, prefix=start_prefix) if len(missing_keys) > 0: logger.info("Weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, missing_keys)) if len(unexpected_keys) > 0: logger.info("Weights from pretrained model not used in {}: {}".format( model.__class__.__name__, unexpected_keys)) if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( model.__class__.__name__, "\n\t".join(error_msgs))) if hasattr(model, 'tie_weights'): model.tie_weights() # make sure word embedding weights are still tied if output_loading_info: loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} return model, loading_info return model class Conv1D(nn.Module): def __init__(self, nf, nx): """ Conv1D layer as defined by Alec for GPT (and also used in GPT-2) Basically works like a Linear layer but the weights are transposed """ super(Conv1D, self).__init__() self.nf = nf w = torch.empty(nx, nf) nn.init.normal_(w, std=0.02) self.weight = nn.Parameter(w) self.bias = nn.Parameter(torch.zeros(nf)) def forward(self, x): size_out = x.size()[:-1] + (self.nf,) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) x = x.view(*size_out) return x class PoolerStartLogits(nn.Module): """ Compute SQuAD start_logits from sequence hidden states. """ def __init__(self, config): super(PoolerStartLogits, self).__init__() self.dense = nn.Linear(config.hidden_size, 1) def forward(self, hidden_states, p_mask=None): """ Args: `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS) shape [batch_size, seq_len]. 1.0 means token should be masked. """ x = self.dense(hidden_states).squeeze(-1) if p_mask is not None: x = x * (1 - p_mask) - 1e30 * p_mask return x class PoolerEndLogits(nn.Module): """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. """ def __init__(self, config): super(PoolerEndLogits, self).__init__() self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) self.activation = nn.Tanh() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense_1 = nn.Linear(config.hidden_size, 1) def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): """ Args: One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states. `start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states. `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS) shape [batch_size, seq_len]. 1.0 means token should be masked. """ slen, hsz = hidden_states.shape[-2:] assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" if start_positions is not None: start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) x = self.activation(x) x = self.LayerNorm(x) x = self.dense_1(x).squeeze(-1) if p_mask is not None: x = x * (1 - p_mask) - 1e30 * p_mask return x class PoolerAnswerClass(nn.Module): """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """ def __init__(self, config): super(PoolerAnswerClass, self).__init__() self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) self.activation = nn.Tanh() self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): """ Args: One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states. `start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states. `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. `cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token. # note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample """ slen, hsz = hidden_states.shape[-2:] assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" if start_positions is not None: start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) if cls_index is not None: cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) else: cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) x = self.activation(x) x = self.dense_1(x).squeeze(-1) return x class SQuADHead(nn.Module): """ A SQuAD head inspired by XLNet. """ def __init__(self, config): super(SQuADHead, self).__init__() self.start_n_top = config.start_n_top self.end_n_top = config.end_n_top self.start_logits = PoolerStartLogits(config) self.end_logits = PoolerEndLogits(config) self.answer_class = PoolerAnswerClass(config) def forward(self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None): """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. """ outputs = () start_logits = self.start_logits(hidden_states, p_mask) if start_positions is not None and end_positions is not None: # If we are on multi-GPU, let's remove the dimension added by batch splitting for x in (start_positions, end_positions, cls_index, is_impossible): if x is not None and x.dim() > 1: x.squeeze_(-1) # during training, compute the end logits based on the ground truth of the start position end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) loss_fct = CrossEntropyLoss() start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if cls_index is not None and is_impossible is not None: # Predict answerability from the representation of CLS and START cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) loss_fct_cls = nn.BCEWithLogitsLoss() cls_loss = loss_fct_cls(cls_logits, is_impossible) # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss total_loss += cls_loss * 0.5 outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs else: outputs = (total_loss, start_logits, end_logits) + outputs else: # during inference, compute the end logits based on beam search bsz, slen, hsz = hidden_states.size() start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top) start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz) start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits) return outputs class SequenceSummary(nn.Module): """ Compute a single vector summary of a sequence hidden states according to various possibilities: Args of the config class: summary_type: - 'last' => [default] take the last token hidden state (like XLNet) - 'first' => take the first token hidden state (like Bert) - 'mean' => take the mean of all tokens hidden states - 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2) - 'attn' => Not implemented now, use multi-head attention summary_use_proj: Add a projection after the vector extraction summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default summary_first_dropout: Add a dropout before the projection and activation summary_last_dropout: Add a dropout after the projection and activation """ def __init__(self, config): super(SequenceSummary, self).__init__() self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last' if config.summary_type == 'attn': # We should use a standard multi-head attention module with absolute positional embedding for that. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 # We can probably just use the multi-head attention module of PyTorch >=1.1.0 raise NotImplementedError self.summary = nn.Identity() if hasattr(config, 'summary_use_proj') and config.summary_use_proj: if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: num_classes = config.num_labels else: num_classes = config.hidden_size self.summary = nn.Linear(config.hidden_size, num_classes) self.activation = nn.Identity() if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': self.activation = nn.Tanh() self.first_dropout = nn.Identity() if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: self.first_dropout = nn.Dropout(config.summary_first_dropout) self.last_dropout = nn.Identity() if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: self.last_dropout = nn.Dropout(config.summary_last_dropout) def forward(self, hidden_states, token_ids=None): """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. token_ids: [optional] index of the classification token if summary_type == 'token_ids', shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. if summary_type == 'token_ids' and token_ids is None: we take the last token of the sequence as classification token """ if self.summary_type == 'last': output = hidden_states[:, -1] elif self.summary_type == 'first': output = hidden_states[:, 0] elif self.summary_type == 'mean': output = hidden_states.mean(dim=1) elif self.summary_type == 'token_ids': if token_ids is None: token_ids = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long) else: token_ids = token_ids.unsqueeze(-1).unsqueeze(-1) token_ids = token_ids.expand((-1,) * (token_ids.dim()-1) + (hidden_states.size(-1),)) # shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states output = hidden_states.gather(-2, token_ids).squeeze(-2) # shape (bsz, XX, hidden_size) elif self.summary_type == 'attn': raise NotImplementedError output = self.first_dropout(output) output = self.summary(output) output = self.activation(output) output = self.last_dropout(output) return output def prune_linear_layer(layer, index, dim=0): """ Prune a linear layer (a model parameters) to keep only entries in index. Return the pruned layer as a new layer with requires_grad=True. Used to remove heads. """ index = index.to(layer.weight.device) W = layer.weight.index_select(dim, index).clone().detach() if layer.bias is not None: if dim == 1: b = layer.bias.clone().detach() else: b = layer.bias[index].clone().detach() new_size = list(layer.weight.size()) new_size[dim] = len(index) new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) new_layer.weight.requires_grad = False new_layer.weight.copy_(W.contiguous()) new_layer.weight.requires_grad = True if layer.bias is not None: new_layer.bias.requires_grad = False new_layer.bias.copy_(b.contiguous()) new_layer.bias.requires_grad = True return new_layer def prune_conv1d_layer(layer, index, dim=1): """ Prune a Conv1D layer (a model parameters) to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. Return the pruned layer as a new layer with requires_grad=True. Used to remove heads. """ index = index.to(layer.weight.device) W = layer.weight.index_select(dim, index).clone().detach() if dim == 0: b = layer.bias.clone().detach() else: b = layer.bias[index].clone().detach() new_size = list(layer.weight.size()) new_size[dim] = len(index) new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) new_layer.weight.requires_grad = False new_layer.weight.copy_(W.contiguous()) new_layer.weight.requires_grad = True new_layer.bias.requires_grad = False new_layer.bias.copy_(b.contiguous()) new_layer.bias.requires_grad = True return new_layer def prune_layer(layer, index, dim=None): """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. Return the pruned layer as a new layer with requires_grad=True. Used to remove heads. """ if isinstance(layer, nn.Linear): return prune_linear_layer(layer, index, dim=0 if dim is None else dim) elif isinstance(layer, Conv1D): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__))