Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[RoBERTa] model conversion, inference, tests 🔥
This commit is contained in:
parent 44dd941efb · commit 05c083520a
@@ -12,6 +12,7 @@ The library currently contains PyTorch implementations, pre-trained model weight
4. **[Transformer-XL](https://github.com/kimiyoung/transformer-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
5. **[XLNet](https://github.com/zihangdai/xlnet/)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
6. **[XLM](https://github.com/facebookresearch/XLM/)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook) released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott et al.

These implementations have been tested on several datasets (see the example scripts) and should match the performance of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performance in the Examples section of the [documentation](https://huggingface.co/pytorch-transformers/examples.html).
164 pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py (Normal file)
@@ -0,0 +1,164 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RoBERTa checkpoint."""

from __future__ import absolute_import, division, print_function

import argparse
import logging

import numpy as np
import torch

from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
                                                BertIntermediate, BertLayer,
                                                BertModel, BertOutput,
                                                BertSelfAttention,
                                                BertSelfOutput)
from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
                                                   RobertaForMaskedLM,
                                                   RobertaModel)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SAMPLE_TEXT = 'Hello world! cécé herlolip'


def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak RoBERTa's weights into our BERT structure.
    """
    roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
    roberta.eval()  # disable dropout
    config = BertConfig(
        vocab_size_or_config_json_file=50265,
        hidden_size=roberta.args.encoder_embed_dim,
        num_hidden_layers=roberta.args.encoder_layers,
        num_attention_heads=roberta.args.encoder_attention_heads,
        intermediate_size=roberta.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
    )
    print("Our BERT config:", config)

    model = RobertaForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
    model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
    model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
    model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight)  # just zero them out b/c RoBERTa doesn't use them.
    model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
    model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
    model.roberta.embeddings.LayerNorm.variance_epsilon = roberta_sent_encoder.emb_layer_norm.eps

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.roberta.encoder.layer[i]
        roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]

        ### self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert (
            roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size))
        )
        # we use three distinct linear layers, so we split the fused source projection here.
        self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :]
        self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size]
        self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :]
        self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size]
        self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :]
        self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:]

        ### self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
        self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
        self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
        self_output.LayerNorm.variance_epsilon = roberta_layer.self_attn_layer_norm.eps

        ### intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
        intermediate.dense.weight = roberta_layer.fc1.weight
        intermediate.dense.bias = roberta_layer.fc1.bias

        ### output
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
        bert_output.dense.weight = roberta_layer.fc2.weight
        bert_output.dense.bias = roberta_layer.fc2.bias
        bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
        bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
        ### end of layer

    # LM Head
    model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
    model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
    model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
    model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
    model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
    model.lm_head.weight = roberta.model.decoder.lm_head.weight
    model.lm_head.bias = roberta.model.decoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    their_output = roberta.model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print(
        "Do both models output the same tensors?",
        "🔥" if success else "💩"
    )
    if not success:
        raise Exception("Something went wRoNg")

    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--roberta_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the official PyTorch dump.")
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_roberta_checkpoint_to_pytorch(
        args.roberta_checkpoint_path,
        args.pytorch_dump_folder_path
    )
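For orientation, a minimal usage sketch (not part of this commit) showing how the conversion function above could be driven directly from Python instead of the command line; the checkpoint directory and output folder below are placeholder paths.

# Hedged usage sketch: paths are placeholders. The checkpoint path should point to a
# local fairseq RoBERTa directory (model.pt + dict.txt); the output folder receives
# the converted weights written by save_pretrained().
from pytorch_transformers.convert_roberta_checkpoint_to_pytorch import (
    convert_roberta_checkpoint_to_pytorch)

convert_roberta_checkpoint_to_pytorch(
    roberta_checkpoint_path="/path/to/roberta.base",
    pytorch_dump_folder_path="/path/to/roberta-base-pytorch",
)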
128 pytorch_transformers/modeling_roberta.py (Normal file)
@@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model."""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
                                                BertLayerNorm, BertModel,
                                                BertPreTrainedModel, gelu)

logger = logging.getLogger(__name__)

ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin",
    'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
    'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
}

ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
    'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
    'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
}


class RobertaEmbeddings(BertEmbeddings):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """
    def __init__(self, config):
        super(RobertaEmbeddings, self).__init__(config)
        self.padding_idx = 1

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            # Position numbers begin at padding_idx+1. Padding symbols are ignored.
            # cf. fairseq's `utils.make_positions`
            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids,
                                                      position_ids=position_ids)


class RobertaConfig(BertConfig):
    pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP


class RobertaModel(BertModel):
    """
    Same as BertModel with:
        - a tiny embeddings tweak.
        - a setup for RoBERTa pretrained models.
    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        super(RobertaModel, self).__init__(config)

        self.embeddings = RobertaEmbeddings(config)


class RobertaForMaskedLM(BertPreTrainedModel):
    """
    RoBERTa Model with a `language modeling` head on top.
    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        super(RobertaForMaskedLM, self).__init__(config)

        self.roberta = RobertaModel(config)
        self.lm_head = RobertaLMHead(config)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
        outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                               attention_mask=attention_mask, head_mask=head_mask)
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        outputs = (prediction_scores,) + outputs[2:]
        return outputs


class RobertaLMHead(nn.Module):
    """RoBERTa head for masked language modeling."""

    def __init__(self, config: BertConfig):
        super(RobertaLMHead, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.weight = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)

        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias

        return x
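A note on the embeddings tweak above: RoBERTa keeps fairseq's position numbering, so real tokens get position ids starting at padding_idx + 1 = 2, which is also why the conversion script sets max_position_embeddings to 514 (512 usable positions plus the two reserved indices). A small illustrative sketch of that convention (not library code):

# Illustrative sketch only: reproduces the position-id convention of RobertaEmbeddings,
# matching fairseq's make_positions for a sequence with no padding.
import torch

padding_idx = 1
input_ids = torch.tensor([[0, 31414, 232, 328, 2]])  # "<s> Hello world! </s>" (ids from the tests)
seq_length = input_ids.size(1)

position_ids = torch.arange(padding_idx + 1, seq_length + padding_idx + 1, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
print(position_ids)  # tensor([[2, 3, 4, 5, 6]])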
69 pytorch_transformers/tests/modeling_roberta_test.py (Normal file)
@@ -0,0 +1,69 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import os
import unittest
import pytest
import torch

from pytorch_transformers.modeling_roberta import (RobertaForMaskedLM,
                                                   RobertaModel)


class RobertaModelTest(unittest.TestCase):

    # @pytest.mark.slow
    def test_inference_masked_lm(self):
        model = RobertaForMaskedLM.from_pretrained('roberta-base')

        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        expected_shape = torch.Size((1, 11, 50265))
        self.assertEqual(
            output.shape,
            expected_shape
        )
        # compare the actual values for a slice.
        expected_slice = torch.Tensor(
            [[[33.8843, -4.3107, 22.7779],
              [ 4.6533, -2.8099, 13.6252],
              [ 1.8222, -3.6898,  8.8600]]]
        )
        self.assertTrue(
            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
        )

    # @pytest.mark.slow
    def test_inference_no_head(self):
        model = RobertaModel.from_pretrained('roberta-base')

        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        # compare the actual values for a slice.
        expected_slice = torch.Tensor(
            [[[-0.0231,  0.0782,  0.0074],
              [-0.1854,  0.0539, -0.0174],
              [ 0.0548,  0.0799,  0.1687]]]
        )
        self.assertTrue(
            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
        )


if __name__ == '__main__':
    unittest.main()
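As a companion to the integration tests above, a hedged sketch of turning the masked-LM scores into per-position predictions; it assumes the 'roberta-base' weights referenced by the tests are resolvable by from_pretrained:

# Sketch only: assumes 'roberta-base' weights and config are available to from_pretrained.
import torch
from pytorch_transformers.modeling_roberta import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained('roberta-base')
model.eval()

# Same encoded sentence as in the tests: "Hello world! cécé herlolip"
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])

with torch.no_grad():
    prediction_scores = model(input_ids)[0]        # shape: (1, 11, 50265)

predicted_ids = prediction_scores.argmax(dim=-1)   # highest-scoring vocab id per position
print(predicted_ids.shape)                         # torch.Size([1, 11])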
42 pytorch_transformers/tests/tokenization_roberta_test.py (Normal file)
@@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import os
import unittest
import pytest

from pytorch_transformers.tokenization_roberta import RobertaTokenizer


class RobertaTokenizationTest(unittest.TestCase):

    # @pytest.mark.slow
    def test_full_tokenizer(self):
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        self.assertListEqual(
            tokenizer.encode('Hello world!'),
            [0, 31414, 232, 328, 2]
        )
        self.assertListEqual(
            tokenizer.encode('Hello world! cécé herlolip'),
            [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
        )


if __name__ == '__main__':
    unittest.main()
218 pytorch_transformers/tokenization_roberta.py (Normal file)
@@ -0,0 +1,218 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import re

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {
    'dict_file': 'dict.txt',
}

PRETRAINED_VOCAB_FILES_MAP = {
    'dict_file':
    {
        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'roberta-base': 512,
    'roberta-large': 512,
    'roberta-large-mnli': 512,
}


SPACE_NORMALIZER = re.compile(r"\s+")


def tokenize_line(line):
    line = SPACE_NORMALIZER.sub(" ", line)
    line = line.strip()
    return line.split()


class Dictionary(object):
    """
    A mapping from symbols to consecutive integers.

    From Facebook's fairseq.
    """

    def __init__(
        self,
        pad='<pad>',
        eos='</s>',
        unk='<unk>',
        bos='<s>',
        extra_special_symbols=None,
    ):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def add_symbol(self, word, n=1):
        """Adds a word to the dictionary"""
        if word in self.indices:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    @classmethod
    def load(cls, f, ignore_utf_errors=False):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f, ignore_utf_errors)
        return d

    def add_from_file(self, f, ignore_utf_errors=False):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                if not ignore_utf_errors:
                    with open(f, 'r', encoding='utf-8') as fd:
                        self.add_from_file(fd)
                else:
                    with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
                        self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception("Incorrect encoding detected in {}, please "
                                "rebuild the dataset".format(f))
            return

        lines = f.readlines()
        for line in lines:
            idx = line.rfind(' ')
            if idx == -1:
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
            word = line[:idx]
            count = int(line[idx + 1:])
            self.indices[word] = len(self.symbols)
            self.symbols.append(word)
            self.count.append(count)

    def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
                    consumer=None, append_eos=True, reverse_order=False):
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = [0] * (nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids


class RobertaTokenizer(PreTrainedTokenizer):
    """
    RoBERTa tokenizer. Peculiarities:
        - a GPT-2 tokenizer with a different integer mapping on top.
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, dict_file,
                 bos_token="<s>", eos_token="</s>", **kwargs):
        super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)

        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.dictionary = Dictionary.load(dict_file)

    def _tokenize(self, text):
        """ Use the GPT-2 tokenizer """
        return self.gpt2_tokenizer._tokenize(text)

    def encode(self, text):
        """ Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
        """
        gpt2_tokens_joined = " ".join(
            str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text))
        )
        bpe_sentence = '<s> ' + gpt2_tokens_joined + ' </s>'
        return self.dictionary.encode_line(bpe_sentence, append_eos=False)

    def _convert_token_to_id(self, token):
        return self.dictionary.index(token)

    def _convert_id_to_token(self, index):
        symbol = self.dictionary[index]
        try:
            idx = int(symbol)
            return self.gpt2_tokenizer._convert_id_to_token(idx)
        except ValueError:
            # special symbols (e.g. <s>, <pad>) are not integers in the fairseq dictionary
            return symbol

    def convert_tokens_to_string(self, tokens):
        return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
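Putting the pieces together, a hedged end-to-end sketch pairing RobertaTokenizer with RobertaModel; it assumes the 'roberta-base' dict.txt and converted weights used in the tests above are available to from_pretrained:

# End-to-end sketch (assumption: 'roberta-base' tokenizer files and weights are resolvable).
import torch
from pytorch_transformers.tokenization_roberta import RobertaTokenizer
from pytorch_transformers.modeling_roberta import RobertaModel

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
model.eval()

# encode() returns fairseq-dictionary ids, already wrapped in <s> ... </s>
ids = tokenizer.encode('Hello world!')    # e.g. [0, 31414, 232, 328, 2], as in the tests
input_ids = torch.tensor([ids])           # add the batch dimension

with torch.no_grad():
    sequence_output = model(input_ids)[0]  # (1, seq_len, hidden_size)
print(sequence_output.shape)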