mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-27 08:18:58 +06:00
145 lines
5.6 KiB
Python
145 lines
5.6 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch RoBERTa model. """
|
|
|
|
from __future__ import (absolute_import, division, print_function,
|
|
unicode_literals)
|
|
|
|
import logging
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
|
|
BertLayerNorm, BertModel,
|
|
BertPreTrainedModel, gelu)
|
|
|
|
# Module-level logger.
logger = logging.getLogger(__name__)

# Shortcut name -> URL of the pretrained RoBERTa weight file (`pytorch_model.bin`).
# Exposed to the loading machinery through `pretrained_model_archive_map` on
# RobertaModel and RobertaForMaskedLM below.
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin",
    'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
    'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
}

# Shortcut name -> URL of the matching JSON configuration file. The keys mirror
# ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP; exposed through
# `pretrained_config_archive_map` on RobertaConfig below.
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
    'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
    'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
}
|
|
|
|
|
|
class RobertaEmbeddings(BertEmbeddings):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    Position ids follow fairseq's `utils.make_positions`:
      - non-padding tokens in each row are numbered padding_idx + 1, padding_idx + 2, ...;
      - padding tokens are given position padding_idx.
    As a result, the real tokens of a sequence receive the same position ids
    whether the batch is left- or right-padded.
    """
    def __init__(self, config):
        super(RobertaEmbeddings, self).__init__(config)
        # Id of the padding token; padding tokens also receive this value as
        # their position id. Hard-coded to 1 (assumed to match the <pad> id of
        # the fairseq/RoBERTa vocabulary -- confirm for custom vocabularies).
        self.padding_idx = 1

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed `input_ids`, deriving position ids from the padding mask when
        `position_ids` is not supplied; everything else is delegated to
        BertEmbeddings.forward.

        Args:
            input_ids: LongTensor of token ids, assumed (batch, seq_len).
            token_type_ids: optional, passed through unchanged.
            position_ids: optional explicit position ids; when given they are
                used as-is.
        """
        if position_ids is None:
            position_ids = self.create_position_ids_from_input_ids(input_ids)
        return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)

    def create_position_ids_from_input_ids(self, input_ids):
        """Return position ids for `input_ids`, skipping padding tokens.

        Example with padding_idx == 1:
            [[5, 6, 1, 1]] -> [[2, 3, 1, 1]]
            [[1, 1, 5, 6]] -> [[1, 1, 2, 3]]

        Args:
            input_ids: LongTensor of token ids, assumed (batch, seq_len).

        Returns:
            LongTensor with the same shape and device as `input_ids`.
        """
        # 1 for real tokens, 0 for padding tokens.
        mask = input_ids.ne(self.padding_idx).long()
        # The running count of real tokens gives 1, 2, ... along each row.
        # Multiplying by the mask zeroes the padding slots, and the offset moves
        # real tokens past padding_idx while padding lands exactly on padding_idx.
        return torch.cumsum(mask, dim=1) * mask + self.padding_idx
|
|
|
|
|
|
class RobertaConfig(BertConfig):
    """Configuration for RoBERTa models.

    Identical to BertConfig except that shortcut names (e.g. 'roberta-base')
    map to the RoBERTa configuration files listed in
    ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP.
    """
    pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
|
|
class RobertaModel(BertModel):
    """
    Same as BertModel with:
    - a tiny embeddings tweak.
    - setup for Roberta pretrained models
    """
    # Class attributes presumably read by the pretrained-model base class (not
    # visible here): the config class to build, the shortcut-name -> weights URL
    # map, and the attribute name under which task heads store this base model
    # (RobertaForMaskedLM uses `self.roberta`).
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        super(RobertaModel, self).__init__(config)

        # Replace the embeddings module (presumably the BertEmbeddings created
        # by BertModel.__init__ -- confirm) with the RoBERTa variant, which
        # numbers positions differently.
        self.embeddings = RobertaEmbeddings(config)
        # nn.Module.apply visits every submodule, so this re-runs init_weights
        # over the whole model, not only the new embeddings. The statement order
        # affects RNG consumption and therefore the initial weights.
        self.apply(self.init_weights)
|
|
|
|
|
|
class RobertaForMaskedLM(BertPreTrainedModel):
    """
    Roberta Model with a `language modeling` head on top.

    `forward` returns a tuple:
      - the masked-LM loss, only when `masked_lm_labels` is given;
      - the vocabulary scores produced by the LM head;
      - any base-model outputs from index 2 onward.
    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        super(RobertaForMaskedLM, self).__init__(config)

        # Construction and initialisation consume the RNG, so keep this order:
        # base model, head, weight init, then tying.
        self.roberta = RobertaModel(config)
        self.lm_head = RobertaLMHead(config)

        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        word_embeddings = self.roberta.embeddings.word_embeddings
        self._tie_or_clone_weights(self.lm_head.decoder, word_embeddings)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, position_ids=None,
                head_mask=None):
        base_outputs = self.roberta(input_ids,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
                                    attention_mask=attention_mask,
                                    head_mask=head_mask)
        scores = self.lm_head(base_outputs[0])

        # Keep base outputs from index 2 on; index 1 is dropped (presumably
        # BertModel's pooled output -- confirm).
        result = (scores,) + base_outputs[2:]
        if masked_lm_labels is None:
            return result

        # Positions labelled -1 do not contribute to the loss.
        criterion = CrossEntropyLoss(ignore_index=-1)
        flat_scores = scores.view(-1, self.config.vocab_size)
        flat_labels = masked_lm_labels.view(-1)
        return (criterion(flat_scores, flat_labels),) + result
|
|
|
|
|
|
class RobertaLMHead(nn.Module):
    """Roberta Head for masked language modeling.

    Applies dense -> gelu -> layer norm, then a bias-free projection to
    vocabulary size plus a separately stored bias vector.
    """

    def __init__(self, config):
        super(RobertaLMHead, self).__init__()
        hidden_size = config.hidden_size
        vocab_size = config.vocab_size

        # Registration order (dense, layer_norm, decoder, bias) is preserved:
        # it fixes RNG use at construction and the state_dict key order.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = BertLayerNorm(hidden_size, eps=config.layer_norm_eps)

        # The projection has no bias of its own. RobertaForMaskedLM ties its
        # weight to the word embeddings, so the output bias is kept here.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features, **kwargs):
        transformed = self.layer_norm(gelu(self.dense(features)))
        # project back to size of vocabulary with bias
        return self.decoder(transformed) + self.bias
|