# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch model utilities: loading pretrained configurations and weights,
the GPT-style `Conv1D` layer, and helpers to prune `nn.Linear` / `Conv1D` layers."""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import os
from io import open

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path

logger = logging.getLogger(__name__)

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
TF_WEIGHTS_NAME = 'model.ckpt'


def _update_config_with_kwargs(config, kwargs):
    """Override attributes of `config` with matching entries of `kwargs`.

    Only keys naming an attribute that already exists on `config` are applied.
    Those keys are removed from `kwargs` in place, so the caller can forward
    the remaining keyword arguments elsewhere.
    """
    used_keys = []
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
            used_keys.append(key)
    # Pop after iterating: a dict must not change size while being iterated.
    for key in used_keys:
        kwargs.pop(key, None)


class PretrainedConfig(object):
    """ Base class handling the downloading / loading of a pretrained model configuration.

    Subclasses fill `pretrained_config_archive_map` (shortcut name -> configuration
    file location). Their constructor must accept a `vocab_size_or_config_json_file`
    keyword argument, because `from_dict` relies on it.
    """
    pretrained_config_archive_map = {}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """ Instantiate a PretrainedConfig from a pre-trained model configuration.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the shortcut name of a pre-trained configuration listed
                  in `cls.pretrained_config_archive_map`, or
                - a path or url to a folder containing:
                    . `config.json` a configuration file for the model
            cache_dir: an optional path to a folder in which the pre-trained model
                configuration will be cached.
            **kwargs: any other key naming an existing attribute of the loaded
                configuration overrides that attribute; other keys are ignored.

        Returns:
            The configuration instance, or None if the configuration file could not
            be resolved (the reason is logged as an error).
        """
        cache_dir = kwargs.pop('cache_dir', None)

        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
        else:
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_config_archive_map.keys()),
                        config_file))
            return None
        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))

        # Load config
        config = cls.from_json_file(resolved_config_file)

        # Update config with kwargs if needed
        _update_config_with_kwargs(config, kwargs)

        logger.info("Model config {}".format(config))
        return config

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `Config` from a Python dictionary of parameters.

        The instance is created with `vocab_size_or_config_json_file=-1`, so the
        subclass constructor must accept that keyword. Every entry of
        `json_object` is then written directly into the instance `__dict__`.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `Config` from a UTF-8 encoded json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (2-space indent, sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a UTF-8 encoded json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())


class PreTrainedModel(nn.Module):
    """ Base class handling weights initialization and a simple interface for
        downloading and loading pretrained models.

    Subclasses are expected to set:
        config_class: the `PretrainedConfig` subclass used to load configurations.
        pretrained_model_archive_map / pretrained_config_archive_map: shortcut
            name -> weights file / configuration file location.
        load_tf_weights: a callable `(model, config, path)` that loads a TensorFlow
            checkpoint into `model`.
        base_model_prefix: the attribute name of the base model inside models with
            heads. It is also the key prefix of the base model's weights in a
            state dict.
    """
    config_class = PretrainedConfig
    pretrained_model_archive_map = {}
    pretrained_config_archive_map = {}
    # Default TensorFlow loader is a no-op; subclasses replace it.
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""

    def __init__(self, config, *inputs, **kwargs):
        """Store `config` on the model.

        Raises:
            ValueError: if `config` is not a `PretrainedConfig` instance.
        """
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the shortcut name of a pre-trained model listed in
                  `cls.pretrained_model_archive_map`, or
                - a path or url to a folder containing:
                    . `config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of the model's state dict
                - a path or url to a folder containing a TensorFlow checkpoint
                  (used with `from_tf=True`):
                    . `config.json` a configuration file for the model
                    . `model.ckpt` a TensorFlow checkpoint (its `.index` file is looked up)
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use
                instead of the downloaded pre-trained weights
            *inputs, **kwargs: additional inputs for the specific model class constructor.
                Keyword arguments naming an existing configuration attribute override
                that attribute and are not forwarded to the constructor.

        Returns:
            The model with its weights loaded, or None if the weights or configuration
            file could not be resolved (the reason is logged as an error). With
            `from_tf=True`, whatever `cls.load_tf_weights` returns.

        Raises:
            RuntimeError: if loading the state dict reported errors.
        """
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)

        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_model_archive_map.keys()),
                        archive_file))
            return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_config_archive_map.keys()),
                        config_file))
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))

        # Load config
        config = cls.config_class.from_json_file(resolved_config_file)

        # Update config with kwargs if needed (consumed keys are not forwarded to the model)
        _update_config_with_kwargs(config, kwargs)

        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint.
            # NOTE(review): the [:-6] slice assumes the resolved path ends in '.index',
            # which holds for the local-folder case built above.
            return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'

        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Load this module's own parameters/buffers, then recurse into its children.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Be able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        base_prefix = cls.base_model_prefix
        # Only remap when a prefix is defined: every key starts with '', so an empty
        # prefix would otherwise produce a bogus '.' prefix and leave every weight unloaded.
        if base_prefix:
            has_prefixed_weights = any(s.startswith(base_prefix) for s in state_dict.keys())
            if not hasattr(model, base_prefix) and has_prefixed_weights:
                # Base model receiving a state dict saved from a model with heads.
                start_prefix = base_prefix + '.'
            if hasattr(model, base_prefix) and not has_prefixed_weights:
                # Model with heads receiving a state dict saved from a base model.
                model_to_load = getattr(model, base_prefix)

        load(model_to_load, prefix=start_prefix)
        if missing_keys:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if unexpected_keys:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if error_msgs:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        return model


def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.

    Params:
        layer: the `nn.Linear` to prune (weight of shape (out_features, in_features)).
        index: a 1-D tensor of indices to keep along `dim`.
        dim: 0 keeps output features (weight rows and matching bias entries);
            1 keeps input features (weight columns; the bias is kept whole).

    Returns:
        A new `nn.Linear` on the same device as `layer`.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    # Temporarily disable grad so the in-place copy into a leaf parameter is allowed.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer


class Conv1D(nn.Module):
    """ Conv1D layer as defined by Alec Radford for GPT (and also used in GPT-2)
        Basically works like a Linear layer but the weights are transposed:
        `weight` has shape (nx, nf), mapping nx input features to nf output features.
    """
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Compute `x @ weight + bias` over the last dimension of `x`, keeping leading dims."""
        size_out = x.size()[:-1] + (self.nf,)
        # reshape (instead of view) also accepts non-contiguous inputs; it returns a
        # view whenever one is possible, so contiguous inputs behave exactly as before.
        x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x


def prune_conv1d_layer(layer, index, dim=1):
    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
        A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.

    Params:
        layer: the `Conv1D` to prune (weight of shape (nx, nf)).
        index: a 1-D tensor of indices to keep along `dim`.
        dim: 1 keeps output features (weight columns and matching bias entries);
            0 keeps input features (weight rows; the bias is kept whole).

    Returns:
        A new `Conv1D` on the same device as `layer`.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else:
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
    # Temporarily disable grad so the in-place copy into a leaf parameter is allowed.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    new_layer.bias.requires_grad = False
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer


def prune_layer(layer, index, dim=None):
    """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.

    `dim` defaults to the output-feature axis of each layer type (0 for nn.Linear,
    1 for Conv1D).

    Raises:
        ValueError: if `layer` is neither an `nn.Linear` nor a `Conv1D`.
    """
    if isinstance(layer, nn.Linear):
        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
    elif isinstance(layer, Conv1D):
        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))