mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-16 11:08:23 +06:00
1308 lines
67 KiB
Python
1308 lines
67 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch BERT model."""
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import sys
|
|
from io import open
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import CrossEntropyLoss, MSELoss
|
|
|
|
from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
|
|
|
|
logger = logging.getLogger(__name__)

# Shortcut names of the hosted BERT checkpoints. Both archive maps below are
# keyed by exactly these names.
_BERT_SHORTCUT_NAMES = (
    'bert-base-uncased',
    'bert-large-uncased',
    'bert-base-cased',
    'bert-large-cased',
    'bert-base-multilingual-uncased',
    'bert-base-multilingual-cased',
    'bert-base-chinese',
    'bert-base-german-cased',
    'bert-large-uncased-whole-word-masking',
    'bert-large-cased-whole-word-masking',
    'bert-large-uncased-whole-word-masking-finetuned-squad',
    'bert-large-cased-whole-word-masking-finetuned-squad',
    'bert-base-cased-finetuned-mrpc',
)

# Common URL prefix; every file is named "<shortcut name>-<suffix>".
_BERT_S3_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert/"

# Shortcut name -> URL of the PyTorch weights file.
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut: "{}{}-pytorch_model.bin".format(_BERT_S3_PREFIX, shortcut)
    for shortcut in _BERT_SHORTCUT_NAMES
}

# Shortcut name -> URL of the JSON configuration file.
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    shortcut: "{}{}-config.json".format(_BERT_S3_PREFIX, shortcut)
    for shortcut in _BERT_SHORTCUT_NAMES
}
|
|
|
|
|
|
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """Copy the variables of a TensorFlow checkpoint into a PyTorch model.

    Each TensorFlow variable name is split on '/' and the components are used
    to walk the attribute tree of `model`; the tensor reached at the end of the
    walk has its `.data` replaced by the checkpoint array.

    Arguments:
        model: PyTorch module that receives the weights (modified in place).
        config: not used by this function.
        tf_checkpoint_path: path of the TensorFlow checkpoint to read.

    Returns:
        The `model` object that was passed in.

    Raises:
        ImportError: if `re`, `numpy` or `tensorflow` cannot be imported.
        AssertionError: if a checkpoint array and the tensor it maps to have
            different shapes (both shapes are appended to the error args).
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions.")
        raise

    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))

    # First pass: read every variable of the checkpoint into memory.
    loaded = []
    for var_name, var_shape in tf.train.list_variables(tf_path):
        print("Loading TF weight {} with shape {}".format(var_name, var_shape))
        loaded.append((var_name, tf.train.load_variable(tf_path, var_name)))

    # Name components that mark optimizer state; variables containing one of
    # them are not needed to use the pretrained model.
    optimizer_scopes = ("adam_v", "adam_m", "global_step")
    weight_aliases = ('kernel', 'gamma', 'output_weights')
    bias_aliases = ('output_bias', 'beta')
    # Matches components such as "layer_3" that carry a list index.
    indexed_scope = re.compile(r'[A-Za-z]+_\d+')

    # Second pass: map each variable onto the matching PyTorch tensor.
    for var_name, array in loaded:
        name = var_name.split('/')
        if any(scope in optimizer_scopes for scope in name):
            print("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        for m_name in name:
            if indexed_scope.fullmatch(m_name):
                parts = re.split(r'_(\d+)', m_name)
            else:
                parts = [m_name]
            attr = parts[0]

            if attr in weight_aliases:
                pointer = getattr(pointer, 'weight')
            elif attr in bias_aliases:
                pointer = getattr(pointer, 'bias')
            elif attr == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, attr)
                except AttributeError:
                    # NOTE(review): this `continue` only skips the current name
                    # component, not the whole variable; the walk resumes with
                    # the next component from the same pointer. Kept as-is.
                    print("Skipping {}".format("/".join(name)))
                    continue
            if len(parts) >= 2:
                pointer = pointer[int(parts[1])]

        # `m_name` is the last component of the variable name at this point.
        if m_name.endswith('_embeddings'):
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)

        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
|
|
|
|
|
def gelu(x):
    """Apply the erf-based GELU activation: ``x * 0.5 * (1 + erf(x / sqrt(2)))``.

    OpenAI GPT uses a slightly different tanh approximation, which gives
    slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    Reference: https://arxiv.org/abs/1606.08415
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)
|
|
|
|
|
|
def swish(x):
    """Apply the swish activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
|
|
# Activation name (as accepted for `hidden_act` when given as a string) -> callable.
ACT2FN = dict(
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)
|
|
|
|
|
|
class BertConfig(PretrainedConfig):
    r"""
        :class:`BertConfig` is the configuration class to store the configuration of a
        `BertModel`.

        Arguments:
            vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel` (int),
                or the path of a JSON configuration file (str). When a path is given, every
                key of the JSON object is copied onto the instance and the other keyword
                arguments below are ignored.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
            **kwargs: forwarded unchanged to `PretrainedConfig.__init__`.

        Raises:
            ValueError: if the first argument is neither an int nor a string.
    """
    # Must reference the map defined in this module; the un-prefixed name
    # `PRETRAINED_CONFIG_ARCHIVE_MAP` is not defined here and raised NameError.
    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 **kwargs):
        """Constructs BertConfig from a vocabulary size or from a JSON configuration file."""
        super(BertConfig, self).__init__(**kwargs)
        # `unicode` only exists on Python 2; the version check short-circuits on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):  # noqa: F821
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            # Every key of the file becomes an instance attribute.
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
            # A space was missing between the two literals ("(int)or").
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")
|
|
|
|
|
|
try:
    # Use the fused apex implementation when apex is installed.
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")

    class BertLayerNorm(nn.Module):
        """Pure-PyTorch layer normalization in the TF style (epsilon inside the square root)."""

        def __init__(self, hidden_size, eps=1e-12):
            """Create a learnable scale (ones) and shift (zeros) of size `hidden_size`."""
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            """Normalize `x` over its last dimension, then apply scale and shift."""
            mean = x.mean(-1, keepdim=True)
            variance = (x - mean).pow(2).mean(-1, keepdim=True)
            normalized = (x - mean) / torch.sqrt(variance + self.variance_epsilon)
            return self.weight * normalized + self.bias
|
|
|
|
class BertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then apply LayerNorm and dropout."""

    def __init__(self, config):
        """Build the three embedding tables from `config` (index 0 is the word padding index)."""
        super(BertEmbeddings, self).__init__()
        hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

        # The attribute is deliberately not snake-cased: it keeps the TensorFlow
        # variable name so that any TensorFlow checkpoint file can be loaded.
        self.LayerNorm = BertLayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """Embed `input_ids`; `token_type_ids` defaults to all zeros.

        Position ids are 0..seq_length-1, where seq_length is `input_ids.size(1)`.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(positions)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))
|
|
|
|
|
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, config):
        """Create the query/key/value projections.

        Raises:
            ValueError: if `config.hidden_size` is not a multiple of
                `config.num_attention_heads`.
        """
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and swap dimensions 1 and 2."""
        per_head_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*per_head_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Attend over `hidden_states`.

        Arguments:
            hidden_states: input of the projections.
            attention_mask: added to the raw attention scores before the softmax.
            head_mask: optional factor multiplied into the attention probabilities.

        Returns:
            `(context_layer, attention_probs)` if `self.output_attentions`,
            otherwise `(context_layer,)`.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores: dot product of queries and keys, scaled by sqrt(head size).
        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        # The mask is additive (precomputed for all layers in BertModel forward()).
        scores = scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the heads back into a single last dimension of size all_head_size.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if self.output_attentions:
            return (context_layer, attention_probs)
        return (context_layer,)
|
|
|
|
|
|
class BertSelfOutput(nn.Module):
    """Output projection of the attention block: dense, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = BertLayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, add the residual `input_tensor` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
class BertAttention(nn.Module):
    """Self-attention (`self.self`) followed by its output projection (`self.output`)."""

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections of this block.

        Arguments:
            heads: iterable of head indices to remove. Repeated indices are
                counted once; an empty iterable is a no-op.
        """
        # De-duplicate first: the head count below is reduced by the number of
        # entries, so a repeated index would otherwise be subtracted twice and
        # leave `num_attention_heads` inconsistent with the pruned layers.
        heads = set(heads)
        if not heads:
            return
        attention = self.self
        # One row per head, one column per feature of that head; 0 marks pruned features.
        keep = torch.ones(attention.num_attention_heads, attention.attention_head_size)
        for head in heads:
            keep[head] = 0
        keep = keep.view(-1).contiguous().eq(1)
        index = torch.arange(len(keep))[keep].long()
        # Prune linear layers
        attention.query = prune_linear_layer(attention.query, index)
        attention.key = prune_linear_layer(attention.key, index)
        attention.value = prune_linear_layer(attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Update hyper params
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads

    def forward(self, input_tensor, attention_mask, head_mask=None):
        """Run self-attention and the output projection.

        Returns:
            `(attention_output,)` followed by whatever extra elements
            `self.self` returned after its first one.
        """
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        return (attention_output,) + self_outputs[1:]  # add attentions if we output them
|
|
|
|
|
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: dense layer to `intermediate_size` plus activation."""

    def __init__(self, config):
        """`config.hidden_act` is either a key of ACT2FN (string) or a callable."""
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # `unicode` only exists on Python 2; the version check short-circuits on Python 3.
        is_name = isinstance(activation, str) or (
            sys.version_info[0] == 2 and isinstance(activation, unicode))  # noqa: F821
        self.intermediate_act_fn = ACT2FN[activation] if is_name else activation

    def forward(self, hidden_states):
        """Return `activation(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
|
|
class BertOutput(nn.Module):
    """Feed-forward contraction: dense back to `hidden_size`, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, add the residual `input_tensor` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
class BertLayer(nn.Module):
    """One encoder layer: attention block, intermediate expansion, output contraction."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Run the layer.

        Returns:
            `(layer_output,)` followed by whatever extra elements the
            attention block returned after its first one.
        """
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attended, extras = attention_outputs[0], attention_outputs[1:]
        # The attention output is both the input and the residual of the feed-forward part.
        layer_output = self.output(self.intermediate(attended), attended)
        return (layer_output,) + extras  # add attentions if we output them
|
|
|
|
|
|
class BertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` BertLayer modules applied in sequence."""

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Run every layer of the stack.

        Arguments:
            hidden_states: input of the first layer.
            attention_mask: passed unchanged to every layer.
            head_mask: optional indexable holding one entry per layer; when
                omitted, every layer receives `None`.

        Returns:
            A tuple `(last_hidden_states, [all_hidden_states], [all_attentions])`.
            `all_hidden_states` (the input of every layer plus the final output)
            is present only if `self.output_hidden_states`; `all_attentions`
            (element 1 of each layer's output) only if `self.output_attentions`.
        """
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # The signature defaults head_mask to None; indexing it
            # unconditionally raised TypeError for that default.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)
|
|
|
|
|
|
class BertPooler(nn.Module):
    """Pool a sequence by passing the hidden state at index 0 through dense + tanh."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means taking the hidden state of the first token.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
|
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Transform applied before the LM decoder: dense, activation, LayerNorm."""

    def __init__(self, config):
        """`config.hidden_act` is either a key of ACT2FN (string) or a callable."""
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        activation = config.hidden_act
        # `unicode` only exists on Python 2; the version check short-circuits on Python 3.
        is_name = isinstance(activation, str) or (
            sys.version_info[0] == 2 and isinstance(activation, unicode))  # noqa: F821
        self.transform_act_fn = ACT2FN[activation] if is_name else activation
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return `LayerNorm(activation(dense(hidden_states)))`."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """Language-model head: transform, then a decoder tied to the input embeddings plus a bias."""

    def __init__(self, config, bert_model_embedding_weights):
        """
        Arguments:
            config: provides the transform settings and the `torchscript` flag.
            bert_model_embedding_weights: embedding matrix used as decoder weight;
                size(0) is the decoder output size, size(1) its input size.
        """
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.torchscript = config.torchscript

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        out_features = bert_model_embedding_weights.size(0)
        in_features = bert_model_embedding_weights.size(1)
        self.decoder = nn.Linear(in_features, out_features, bias=False)

        # With torchscript the decoder gets its own copy of the embedding matrix;
        # otherwise the very same parameter object is shared.
        if self.torchscript:
            self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
        else:
            self.decoder.weight = bert_model_embedding_weights

        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        """Return `decoder(transform(hidden_states)) + bias`."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed) + self.bias
|
|
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Wrapper exposing only the masked-LM prediction head under `predictions`."""

    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        """Return the prediction scores computed from `sequence_output`."""
        return self.predictions(sequence_output)
|
|
|
|
|
|
class BertOnlyNSPHead(nn.Module):
    """Wrapper exposing only the two-way sequence-relationship classifier."""

    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two relationship scores computed from `pooled_output`."""
        return self.seq_relationship(pooled_output)
|
|
|
|
|
|
class BertPreTrainingHeads(nn.Module):
    """Both pre-training heads: masked-LM predictions and sequence relationship."""

    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return `(prediction_scores, seq_relationship_score)`."""
        lm_scores = self.predictions(sequence_output)
        relationship_scores = self.seq_relationship(pooled_output)
        return lm_scores, relationship_scores
|
|
|
|
|
|
class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    # Class-level hooks: BERT configuration class, weight archive map,
    # TensorFlow loader and attribute name of the base model.
    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def __init__(self, *inputs, **kwargs):
        super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights of a single sub-module (meant to be used with `self.apply`).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Only linear layers carry a bias, and it may be disabled.
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
|
class BertModel(BertPreTrainedModel):
    r"""BERT model ("Bidirectional Encoder Representations from Transformers").

    :class:`BertModel` is the basic BERT Transformer model with a layer of summed token,
    position and sequence embeddings followed by a series of identical self-attention blocks
    (12 for BERT-base, 24 for BERT-large).

    Arguments:
        config: a BertConfig class instance with the configuration to build a new model.
            The sub-modules read ``config.output_attentions`` and ``config.output_hidden_states``
            to decide whether attention probabilities and per-layer hidden states are returned.

    Example::

        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertModel(config)

    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.apply(self.init_weights)

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Arguments:
            input_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the word token
                indices in the vocabulary.
            token_type_ids: an optional torch.LongTensor of shape [batch_size, sequence_length] with the
                token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1
                corresponds to a `sentence B` token (see BERT paper for more details).
                Defaults to all zeros.
            attention_mask: an optional torch.LongTensor of shape [batch_size, sequence_length] with
                values selected in [0, 1]: 1 for positions to attend to, 0 for masked (e.g. padding)
                positions. It's the mask that we typically use for attention when a batch has varying
                length sentences. Defaults to all ones.
            head_mask: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with
                values between 0 and 1. The attention probabilities of each head are multiplied by
                its value: 1.0 => the head is kept, 0.0 => the head is nullified.

        Returns:
            A tuple ``(sequence_output, pooled_output, [hidden_states], [attentions])``.

            ``sequence_output`` is the output of the last encoder layer, of shape
            [batch_size, sequence_length, hidden_size].

            ``pooled_output`` is of shape [batch_size, hidden_size]: the hidden state of the first
            token of the input (`CLS`) passed through the pooler's dense layer and tanh.

            ``hidden_states`` is only present when ``config.output_hidden_states`` is set: a tuple with
            the embedding output followed by the output of every encoder layer.

            ``attentions`` is only present when ``config.output_attentions`` is set: a tuple with the
            attention probabilities of every encoder layer.

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)[:2]
            # or
            sequence_output, pooled_output = model.forward(input_ids, token_type_ids, input_mask)[:2]

        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 4D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x 1 x num_heads x 1 x 1],
        # so that head_mask[i] broadcasts against the attention_probs of layer i
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
|
|
|
|
|
|
class BertForPreTraining(BertPreTrainedModel):
    """BERT encoder topped with the two pre-training heads.

    The wrapped `BertModel` produces a sequence output and a pooled output, which
    `BertPreTrainingHeads` turns into:
        - masked language modeling scores, and
        - next sentence classification scores.

    Args:
        `config`: a BertConfig class instance with the configuration to build a new model
        `output_attentions` / `output_hidden_states`: not parameters of `__init__`; presumably read from
            `config` by the encoder (TODO confirm). Whatever extra entries `BertModel` returns after its
            first two outputs are appended unchanged to the tuple returned by `forward`.

    Example ::

        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertForPreTraining(config)
    """
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)

        self.bert = BertModel(config)
        # The heads are handed the input word-embedding matrix (presumably so the LM decoder
        # can share it -- confirm in BertPreTrainingHeads).
        word_embedding_weight = self.bert.embeddings.word_embeddings.weight
        self.cls = BertPreTrainingHeads(config, word_embedding_weight)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
                next_sentence_label=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Args:
            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
                a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
                selected in [0, 1]. It's the mask typically used for attention when a batch has varying
                length sentences.
            `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape
                [batch_size, sequence_length]. Labels equal to -1 are ignored by the loss; the other labels
                are class indices into the `vocab_size` prediction scores.
            `next_sentence_label`: optional next sentence classification labels: torch.LongTensor of shape
                [batch_size] with indices selected in [0, 1].
                0 => next sentence is the continuation, 1 => next sentence is a random sentence.
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, forwarded to `BertModel`. 1.0 => the head is kept, 0.0 => the head is masked.

        Returns:
            A tuple containing, in this order:
                - `total_loss`, only when both `masked_lm_labels` and `next_sentence_label` are given:
                  the sum of the masked language modeling loss and the next sentence classification loss,
                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size],
                - the next sentence classification logits of shape [batch_size, 2],
                - any further outputs of `BertModel` (hidden states, attentions) if they are present.

        Example ::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

            model = BertForPreTraining(config)
            masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)[:2]
            # or
            masked_lm_logits_scores, seq_relationship_logits = model.forward(input_ids, token_type_ids, input_mask)[:2]
        """
        bert_outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)

        sequence_output, pooled_output = bert_outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # Hidden states and attentions, if the encoder returned them.
        extras = bert_outputs[2:]
        outputs = (prediction_scores, seq_relationship_score,) + extras

        # The loss needs both label sets; with either one missing only the scores are returned.
        if masked_lm_labels is None or next_sentence_label is None:
            return outputs  # prediction_scores, seq_relationship_score, (hidden_states), (attentions)

        # The same criterion (ignoring label -1) is used for both heads.
        criterion = CrossEntropyLoss(ignore_index=-1)
        flat_scores = prediction_scores.view(-1, self.config.vocab_size)
        masked_lm_loss = criterion(flat_scores, masked_lm_labels.view(-1))
        next_sentence_loss = criterion(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
        total_loss = masked_lm_loss + next_sentence_loss

        return (total_loss,) + outputs  # loss, prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
|
|
|
|
|
class BertForMaskedLM(BertPreTrainedModel):
    """BERT encoder followed by the masked language modeling head only.

    Args:
        `config`: a BertConfig class instance with the configuration to build a new model
        `output_attentions` / `output_hidden_states`: not parameters of `__init__`; presumably read from
            `config` by the encoder (TODO confirm). Extra entries returned by `BertModel` after its first
            two outputs are appended unchanged to the tuple returned by `forward`.

    Example::

        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertForMaskedLM(config)
    """
    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)

        self.bert = BertModel(config)
        # The head is handed the input word-embedding matrix (presumably so the LM decoder
        # can share it -- confirm in BertOnlyMLMHead).
        word_embedding_weight = self.bert.embeddings.word_embeddings.weight
        self.cls = BertOnlyMLMHead(config, word_embedding_weight)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Args:
            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
                a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
                selected in [0, 1]. It's the mask typically used for attention when a batch has varying
                length sentences.
            `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape
                [batch_size, sequence_length]. Labels equal to -1 are ignored by the loss; the other labels
                are class indices into the `vocab_size` prediction scores.
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, forwarded to `BertModel`. 1.0 => the head is kept, 0.0 => the head is masked.

        Returns:
            A tuple containing, in this order:
                - the masked language modeling loss, only when `masked_lm_labels` is given,
                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size],
                - any further outputs of `BertModel` (hidden states, attentions) if they are present.

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)[0]
            # or
            masked_lm_logits_scores = model.forward(input_ids, token_type_ids, input_mask)[0]
        """
        bert_outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)

        prediction_scores = self.cls(bert_outputs[0])

        # Skip the pooled output (index 1); keep hidden states and attentions if they are there.
        outputs = (prediction_scores,) + bert_outputs[2:]

        if masked_lm_labels is None:
            return outputs  # prediction_scores, (hidden_states), (attentions)

        # Positions labelled -1 do not contribute to the loss.
        criterion = CrossEntropyLoss(ignore_index=-1)
        flat_scores = prediction_scores.view(-1, self.config.vocab_size)
        masked_lm_loss = criterion(flat_scores, masked_lm_labels.view(-1))

        return (masked_lm_loss,) + outputs  # masked_lm_loss, prediction_scores, (hidden_states), (attentions)
|
|
|
|
|
|
class BertForNextSentencePrediction(BertPreTrainedModel):
    """BERT encoder followed by the next sentence classification head only.

    Args:
        `config`: a BertConfig class instance with the configuration to build a new model
        `output_attentions` / `output_hidden_states`: not parameters of `__init__`; presumably read from
            `config` by the encoder (TODO confirm). Extra entries returned by `BertModel` after its first
            two outputs are appended unchanged to the tuple returned by `forward`.

    Example::

        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertForNextSentencePrediction(config)
    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Args:
            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary (see the tokens pre-processing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
                a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
                selected in [0, 1]. It's the mask typically used for attention when a batch has varying
                length sentences.
            `next_sentence_label`: optional next sentence classification labels: torch.LongTensor of shape
                [batch_size] with indices selected in [0, 1].
                0 => next sentence is the continuation, 1 => next sentence is a random sentence.
                Labels equal to -1 are ignored by the loss.
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, forwarded to `BertModel`. 1.0 => the head is kept, 0.0 => the head is masked.

        Returns:
            A tuple containing, in this order:
                - the next sentence classification loss, only when `next_sentence_label` is given,
                - the next sentence classification logits of shape [batch_size, 2],
                - any further outputs of `BertModel` (hidden states, attentions) if they are present.

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            seq_relationship_logits = model(input_ids, token_type_ids, input_mask)[0]
            # or
            seq_relationship_logits = model.forward(input_ids, token_type_ids, input_mask)[0]
        """
        bert_outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)

        # Only the pooled output (index 1) feeds the classification head.
        seq_relationship_score = self.cls(bert_outputs[1])

        # Keep hidden states and attentions if the encoder returned them.
        outputs = (seq_relationship_score,) + bert_outputs[2:]

        if next_sentence_label is None:
            return outputs  # seq_relationship_score, (hidden_states), (attentions)

        criterion = CrossEntropyLoss(ignore_index=-1)
        next_sentence_loss = criterion(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))

        return (next_sentence_loss,) + outputs  # next_sentence_loss, seq_relationship_score, (hidden_states), (attentions)
|
|
|
|
|
|
class BertForSequenceClassification(BertPreTrainedModel):
    """BERT model for sequence-level classification or regression.

    A dropout layer and a linear layer are applied on top of the pooled output of `BertModel`.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
            `config.num_labels` sets the output size of the classifier; a value of 1 switches the
            loss in `forward` to regression (MSE).
        `output_attentions` / `output_hidden_states`: not parameters of `__init__`; presumably read from
            `config` by the encoder (TODO confirm). Extra entries returned by `BertModel` after its first
            two outputs are appended unchanged to the tuple returned by `forward`.

    Example::

        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertForSequenceClassification(config)
    """
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Parameters:
            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary. Items in the batch should begin with the
                special "CLS" token. (see the tokens preprocessing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
                a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
                selected in [0, 1]. It's the mask typically used for attention when a batch has varying
                length sentences.
            `labels`: optional labels of shape [batch_size]: class indices for classification, or
                regression targets when `num_labels` is 1.
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, forwarded to `BertModel`. 1.0 => the head is kept, 0.0 => the head is masked.

        Returns:
            A tuple containing, in this order:
                - the loss, only when `labels` is given (MSE if `num_labels` is 1, CrossEntropy otherwise),
                - the logits of shape [batch_size, num_labels],
                - any further outputs of `BertModel` (hidden states, attentions) if they are present.

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            logits = model(input_ids, token_type_ids, input_mask)[0]
            # or
            logits = model.forward(input_ids, token_type_ids, input_mask)[0]
        """
        bert_outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)

        # Classify from the pooled output (index 1) after dropout.
        logits = self.classifier(self.dropout(bert_outputs[1]))

        # Keep hidden states and attentions if the encoder returned them.
        outputs = (logits,) + bert_outputs[2:]

        if labels is None:
            return outputs  # logits, (hidden_states), (attentions)

        if self.num_labels == 1:
            # A single output unit means we are doing regression.
            loss = MSELoss()(logits.view(-1), labels.view(-1))
        else:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss,) + outputs  # loss, logits, (hidden_states), (attentions)
|
|
|
|
|
|
class BertForMultipleChoice(BertPreTrainedModel):
    """BERT model for multiple choice tasks.

    Every choice is encoded independently by `BertModel`; a dropout layer and a linear layer with a
    single output unit score each pooled output, and the scores are regrouped per example.

    Parameters:
        `config`: a BertConfig class instance with the configuration to build a new model
        `output_attentions` / `output_hidden_states`: not parameters of `__init__`; presumably read from
            `config` by the encoder (TODO confirm). Extra entries returned by `BertModel` after its first
            two outputs are appended unchanged to the tuple returned by `forward`.

    Example::

        # Already been converted into WordPiece token ids
        input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
        input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
        token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertForMultipleChoice(config)
        logits = model(input_ids, token_type_ids, input_mask)[0]
    """
    def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One score per (example, choice) pair.
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Parameters:
            `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
                with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
                with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
                and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
                with indices selected in [0, 1]. It's the mask typically used for attention when a batch has
                varying length sentences.
            `labels`: optional labels for the classification output: torch.LongTensor of shape [batch_size]
                holding the index of the correct choice for each example.
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, forwarded to `BertModel`. 1.0 => the head is kept, 0.0 => the head is masked.

        Returns:
            A tuple containing, in this order:
                - the CrossEntropy classification loss, only when `labels` is given,
                - the logits reshaped to [batch_size, num_choices],
                - any further outputs of `BertModel` (hidden states, attentions) if they are present.

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
            input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
            token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
            config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

            model = BertForMultipleChoice(config)
            logits = model(input_ids, token_type_ids, input_mask)[0]
        """
        # Input shapes should be [bsz, num choices, seq length]
        num_choices = input_ids.shape[1]

        def _merge_choices(tensor):
            # Fold every leading dimension into one so each choice becomes its own row;
            # optional inputs that were not supplied stay None.
            if tensor is None:
                return None
            return tensor.view(-1, tensor.size(-1))

        bert_outputs = self.bert(_merge_choices(input_ids),
                                 _merge_choices(token_type_ids),
                                 _merge_choices(attention_mask),
                                 head_mask=head_mask)

        # Score every flattened choice from its pooled output (index 1) ...
        choice_scores = self.classifier(self.dropout(bert_outputs[1]))
        # ... then regroup the scores so that each row holds the choices of one example.
        reshaped_logits = choice_scores.view(-1, num_choices)

        # Keep hidden states and attentions if the encoder returned them.
        outputs = (reshaped_logits,) + bert_outputs[2:]

        if labels is None:
            return outputs  # reshaped_logits, (hidden_states), (attentions)

        loss = CrossEntropyLoss()(reshaped_logits, labels)

        return (loss,) + outputs  # loss, reshaped_logits, (hidden_states), (attentions)
|
|
|
|
|
|
class BertForTokenClassification(BertPreTrainedModel):
    """BERT model for token-level classification.

    A dropout layer and a linear layer are applied on top of the full sequence output
    (first output) of `BertModel`.

    Parameters:
        `config`: a BertConfig class instance with the configuration to build a new model.
            `config.num_labels` sets the output size of the classifier.
        `output_attentions` / `output_hidden_states`: not parameters of `__init__`; presumably read from
            `config` by the encoder (TODO confirm). Extra entries returned by `BertModel` after its first
            two outputs are appended unchanged to the tuple returned by `forward`.

    Example::

        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertForTokenClassification(config)
    """
    def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Parameters:
            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary (see the tokens pre-processing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
                a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
                selected in [0, 1]. It's the mask typically used for attention when a batch has varying
                length sentences. When given, only positions whose mask value equals 1 contribute to the loss.
            `labels`: optional labels for the classification output: torch.LongTensor of shape
                [batch_size, sequence_length] holding one class index per token.
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, forwarded to `BertModel`. 1.0 => the head is kept, 0.0 => the head is masked.

        Returns:
            A tuple containing, in this order:
                - the CrossEntropy classification loss, only when `labels` is given,
                - the logits of shape [batch_size, sequence_length, num_labels],
                - any further outputs of `BertModel` (hidden states, attentions) if they are present.

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            logits = model(input_ids, token_type_ids, input_mask)[0]
            # or
            logits = model.forward(input_ids, token_type_ids, input_mask)[0]
        """
        bert_outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)

        # Classify every token from the sequence output (index 0) after dropout.
        logits = self.classifier(self.dropout(bert_outputs[0]))

        # Skip the pooled output; keep hidden states and attentions if they are there.
        outputs = (logits,) + bert_outputs[2:]

        if labels is None:
            return outputs  # logits, (hidden_states), (attentions)

        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask is not None:
            # Only keep active parts of the loss
            active = attention_mask.view(-1) == 1
            flat_logits = flat_logits[active]
            flat_labels = flat_labels[active]
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

        return (loss,) + outputs  # loss, logits, (hidden_states), (attentions)
|
|
|
|
|
|
class BertForQuestionAnswering(BertPreTrainedModel):
    """BERT model for Question Answering (span extraction).

    A linear layer on top of the sequence output (first output) of `BertModel` computes, for every
    token, a start logit and an end logit.

    Parameters:
        `config`: a BertConfig class instance with the configuration to build a new model.
            `config.num_labels` sets the output size of the linear layer; `forward` unpacks its
            output into exactly two size-1 slices, so it has to be 2.
        `output_attentions` / `output_hidden_states`: not parameters of `__init__`; presumably read from
            `config` by the encoder (TODO confirm). Extra entries returned by `BertModel` after its first
            two outputs are appended unchanged to the tuple returned by `forward`.

    Example::

        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        model = BertForQuestionAnswering(config)
    """
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
                end_positions=None, head_mask=None):
        """
        Performs a model forward pass. Can be called by calling the class directly, once it has been instantiated.

        Parameters:
            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
                a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
                selected in [0, 1]. It's the mask typically used for attention when a batch has varying
                length sentences.
            `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape
                [batch_size] (a trailing extra dimension is squeezed away). Positions are clamped to
                [0, sequence_length]; positions at or beyond `sequence_length` end up equal to the ignored
                index and do not contribute to the loss, negative positions are clamped to 0.
            `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape
                [batch_size], treated exactly like `start_positions`.
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, forwarded to `BertModel`. 1.0 => the head is kept, 0.0 => the head is masked.

            The tensors passed as `start_positions` and `end_positions` are never modified.

        Returns:
            A tuple containing, in this order:
                - `total_loss`, only when both `start_positions` and `end_positions` are given: the mean of
                  the CrossEntropy losses for the start and end token positions,
                - `start_logits` of shape [batch_size, sequence_length],
                - `end_logits` of shape [batch_size, sequence_length],
                - any further outputs of `BertModel` (hidden states, attentions) if they are present.

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            start_logits, end_logits = model(input_ids, token_type_ids, input_mask)[:2]
        """
        bert_outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
        sequence_output = bert_outputs[0]

        # Split the per-token outputs into the start and the end logits and drop the size-1 last dim.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # Skip the pooled output; keep hidden states and attentions if they are there.
        outputs = (start_logits, end_logits,) + bert_outputs[2:]

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Sometimes the start/end positions are outside our model inputs, we ignore these terms:
            # the index one past the last token doubles as the loss's ignore_index.
            ignored_index = start_logits.size(1)
            # Out-of-place clamp: the in-place variant would overwrite the caller's label tensors
            # (squeeze returns a view, so even the squeezed case would write through to them).
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
|