xlm
parent 708877958a
commit 288be7b7ea
73
pytorch_pretrained_bert/convert_xlm_checkpoint_to_pytorch.py
Executable file
@ -0,0 +1,73 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert an XLM checkpoint to a PyTorch checkpoint."""

from __future__ import absolute_import, division, print_function

import argparse
import json
from io import open

import torch
import numpy

from pytorch_pretrained_bert.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
from pytorch_pretrained_bert.tokenization_xlm import MERGES_NAME, VOCAB_NAME


def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
    # Load checkpoint
    chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')

    model = chkpt['model']

    # Keep only JSON-serializable parameters (drop tensors/arrays) for the configuration file
    config = chkpt['params']
    config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.Tensor, numpy.ndarray)))

    # Convert the vocabulary to the BPE conventions used here: tokens without a '@@'
    # continuation marker get a '</w>' suffix, '@@' markers are stripped
    vocab = chkpt['dico_word2id']
    vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME

    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model, pytorch_weights_dump_path)

    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(config, indent=2) + "\n")

    print("Save vocab file to {}".format(pytorch_vocab_dump_path))
    with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(vocab, indent=2) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--xlm_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the official PyTorch dump.")
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
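For orientation, a minimal sketch of driving the converter programmatically and inspecting what it writes. The checkpoint filename and dump directory are hypothetical placeholders, and the sketch assumes the converter module is importable from the package (the script can equally be run with `--xlm_checkpoint_path` and `--pytorch_dump_folder_path`).

```python
import json
import os

import torch

from pytorch_pretrained_bert.modeling_xlm import CONFIG_NAME, WEIGHTS_NAME
# Hypothetical import path, assuming the converter ships inside the package.
from pytorch_pretrained_bert.convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch

dump_dir = "./xlm_dump"                                          # hypothetical, must already exist
convert_xlm_checkpoint_to_pytorch("mlm_en_2048.pth", dump_dir)   # hypothetical checkpoint file

# Inspect the artifacts the converter wrote.
state_dict = torch.load(os.path.join(dump_dir, WEIGHTS_NAME), map_location="cpu")
with open(os.path.join(dump_dir, CONFIG_NAME), encoding="utf-8") as f:
    config = json.load(f)
print(len(state_dict), "tensors saved; emb_dim =", config.get("emb_dim"))
```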
@ -72,29 +72,22 @@ class XLMConfig(PretrainedConfig):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file,
|
||||
causal=True,
|
||||
d_model=1024,
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
d_inner=4096,
|
||||
ff_activation="gelu",
|
||||
untie_r=True,
|
||||
attn_type="bi",
|
||||
|
||||
n_special=0,
|
||||
emb_dim=2048,
|
||||
n_layers=12,
|
||||
n_heads=16,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
gelu_activation=True,
|
||||
sinusoidal_embeddings=False,
|
||||
asm=False,
|
||||
id2lang={ 0: "en" },
|
||||
lang2id={ "en": 0 },
|
||||
n_langs=1,
|
||||
n_words=30145,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
|
||||
dropout=0.1,
|
||||
dropatt=0.1,
|
||||
init="normal",
|
||||
init_range=0.1,
|
||||
init_std=0.02,
|
||||
mem_len=None,
|
||||
reuse_len=None,
|
||||
bi_data=False,
|
||||
clamp_len=-1,
|
||||
same_length=False):
|
||||
**kwargs):
|
||||
"""Constructs XLMConfig.
|
||||
|
||||
Args:
|
||||
@ -137,6 +130,8 @@ class XLMConfig(PretrainedConfig):
|
||||
-1 means no clamping.
|
||||
same_length: bool, whether to use the same attention length for each token.
|
||||
"""
|
||||
super(XLMConfig, self).__init__(**kwargs)
|
||||
|
||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||
@ -144,36 +139,41 @@ class XLMConfig(PretrainedConfig):
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif isinstance(vocab_size_or_config_json_file, int):
|
||||
self.n_token = vocab_size_or_config_json_file
|
||||
self.causal = causal
|
||||
self.d_model = d_model
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
assert d_model % n_head == 0
|
||||
self.d_head = d_model // n_head
|
||||
self.ff_activation = ff_activation
|
||||
self.d_inner = d_inner
|
||||
self.untie_r = untie_r
|
||||
self.attn_type = attn_type
|
||||
|
||||
self.n_words = vocab_size_or_config_json_file
|
||||
self.n_special = n_special
|
||||
self.emb_dim = emb_dim
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.gelu_activation = gelu_activation
|
||||
self.sinusoidal_embeddings = sinusoidal_embeddings
|
||||
self.asm = asm
|
||||
self.id2lang = id2lang
|
||||
self.lang2id = lang2id
|
||||
self.n_langs = n_langs
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
self.init = init
|
||||
self.init_range = init_range
|
||||
self.init_std = init_std
|
||||
self.dropout = dropout
|
||||
self.dropatt = dropatt
|
||||
self.mem_len = mem_len
|
||||
self.reuse_len = reuse_len
|
||||
self.bi_data = bi_data
|
||||
self.clamp_len = clamp_len
|
||||
self.same_length = same_length
|
||||
else:
|
||||
raise ValueError("First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)")
|
||||
|
||||
@property
|
||||
def total_tokens_embeddings(self):
|
||||
return self.n_words + self.n_special
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.emb_dim
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.n_heads
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.n_layers
|
||||
|
||||
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLMLayerNorm
|
||||
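As a quick illustration of the reworked configuration, a minimal sketch that builds an XLMConfig from a vocabulary size and checks the BERT-style property aliases defined above; the keyword values are illustrative rather than the pretrained model's, and the sketch assumes PretrainedConfig's keyword defaults suffice.

```python
from pytorch_pretrained_bert.modeling_xlm import XLMConfig

# Illustrative hyper-parameters; anything omitted falls back to the defaults above.
config = XLMConfig(vocab_size_or_config_json_file=30145,
                   emb_dim=1024, n_layers=6, n_heads=8, n_langs=1)

assert config.hidden_size == config.emb_dim
assert config.num_attention_heads == config.n_heads
assert config.num_hidden_layers == config.n_layers
print(config.total_tokens_embeddings)   # n_words + n_special -> 30145
```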
@ -259,9 +259,10 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
NEW_ID = itertools.count()
|
||||
|
||||
def __init__(self, n_heads, dim, dropout):
|
||||
def __init__(self, n_heads, dim, dropout, output_attentions=False):
|
||||
super().__init__()
|
||||
self.layer_id = next(MultiHeadAttention.NEW_ID)
|
||||
self.output_attentions = output_attentions
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.dropout = dropout
|
||||
@ -325,7 +326,10 @@ class MultiHeadAttention(nn.Module):
|
||||
context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
|
||||
context = unshape(context) # (bs, qlen, dim)
|
||||
|
||||
return self.out_lin(context)
|
||||
outputs = (self.out_lin(context),)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (weights,)
|
||||
return outputs
|
||||
|
||||
|
||||
class TransformerFFN(nn.Module):
|
||||
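One detail worth flagging in the tuple-style return: a one-element tuple needs its trailing comma, so `outputs + (weights,)` concatenates tuples, whereas `outputs + (weights)` would try to add a tensor to a tuple. A tiny standalone illustration:

```python
import torch

context_out = torch.zeros(2, 5, 16)   # stand-in for self.out_lin(context)
weights = torch.zeros(2, 4, 5, 5)     # stand-in for the attention weights

outputs = (context_out,)
outputs = outputs + (weights,)        # tuple concatenation -> (context_out, weights)
assert isinstance(outputs, tuple) and len(outputs) == 2

# (weights) without the comma is just the tensor itself, and tuple + tensor raises TypeError.
```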
@ -345,52 +349,6 @@ class TransformerFFN(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class BeamHypotheses(object):
|
||||
|
||||
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
|
||||
"""
|
||||
Initialize n-best list of hypotheses.
|
||||
"""
|
||||
self.max_len = max_len - 1 # ignoring <BOS>
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.n_hyp = n_hyp
|
||||
self.hyp = []
|
||||
self.worst_score = 1e9
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Number of hypotheses in the list.
|
||||
"""
|
||||
return len(self.hyp)
|
||||
|
||||
def add(self, hyp, sum_logprobs):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
score = sum_logprobs / len(hyp) ** self.length_penalty
|
||||
if len(self) < self.n_hyp or score > self.worst_score:
|
||||
self.hyp.append((score, hyp))
|
||||
if len(self) > self.n_hyp:
|
||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
|
||||
del self.hyp[sorted_scores[0][1]]
|
||||
self.worst_score = sorted_scores[1][0]
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs):
|
||||
"""
|
||||
If there are enough hypotheses and none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
if len(self) < self.n_hyp:
|
||||
return False
|
||||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty
|
||||
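For reference, a small worked example of how this BeamHypotheses helper behaves; the token ids and log-probabilities are made up, and the scores in the comments follow the `sum_logprobs / len(hyp) ** length_penalty` formula above.

```python
hyps = BeamHypotheses(n_hyp=2, max_len=10, length_penalty=1.0, early_stopping=False)

hyps.add([5, 7, 9], sum_logprobs=-3.0)     # score -1.0, list not full yet
hyps.add([5, 7], sum_logprobs=-1.0)        # score -0.5, list now holds 2 hypotheses
hyps.add([5, 8, 2, 4], sum_logprobs=-2.0)  # score -0.5, evicts the score -1.0 hypothesis

assert len(hyps) == 2
print(hyps.is_done(best_sum_logprobs=-20.0))  # True: no pending beam can beat the worst kept score
```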
|
||||
|
||||
class XLMPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
@ -410,16 +368,11 @@ class XLMPreTrainedModel(PreTrainedModel):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, XLMLayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, XLMRelativeAttention):
|
||||
for param in [module.q, module.k, module.v, module.o, module.r,
|
||||
module.r_r_bias, module.r_s_bias, module.r_w_bias,
|
||||
module.seg_embed]:
|
||||
param.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class XLMModel(XLMPreTrainedModel):
|
||||
@ -429,7 +382,7 @@ class XLMModel(XLMPreTrainedModel):
|
||||
'hidden_dim', 'dropout', 'attention_dropout', 'asm',
|
||||
'asm_cutoffs', 'asm_div_value']
|
||||
|
||||
def __init__(self, params, output_attentions=False, output_hidden_states=False): #, dico, is_encoder, with_output):
|
||||
def __init__(self, config): #, dico, is_encoder, with_output):
|
||||
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
|
||||
Paper: https://arxiv.org/abs/1901.07291
|
||||
Original code: https://github.com/facebookresearch/XLM
|
||||
@ -481,41 +434,41 @@ class XLMModel(XLMPreTrainedModel):
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
super(XLMModel, self).__init__(params)
|
||||
self.output_attentions = output_attentions
|
||||
self.output_hidden_states = output_hidden_states
|
||||
super(XLMModel, self).__init__(config)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
# encoder / decoder, output layer
|
||||
# self.is_encoder = is_encoder
|
||||
# self.is_decoder = not is_encoder
|
||||
# self.with_output = with_output
|
||||
self.causal = params.causal
|
||||
self.causal = config.causal
|
||||
|
||||
# dictionary / languages
|
||||
self.n_langs = params.n_langs
|
||||
self.n_words = params.n_words
|
||||
self.eos_index = params.eos_index
|
||||
self.pad_index = params.pad_index
|
||||
self.n_langs = config.n_langs
|
||||
self.n_words = config.n_words
|
||||
self.eos_index = config.eos_index
|
||||
self.pad_index = config.pad_index
|
||||
# self.dico = dico
|
||||
self.id2lang = params.id2lang
|
||||
self.lang2id = params.lang2id
|
||||
self.id2lang = config.id2lang
|
||||
self.lang2id = config.lang2id
|
||||
# assert len(self.dico) == self.n_words
|
||||
assert len(self.id2lang) == len(self.lang2id) == self.n_langs
|
||||
|
||||
# model parameters
|
||||
self.dim = params.emb_dim # 512 by default
|
||||
self.dim = config.emb_dim # 512 by default
|
||||
self.hidden_dim = self.dim * 4 # 2048 by default
|
||||
self.n_heads = params.n_heads # 8 by default
|
||||
self.n_layers = params.n_layers
|
||||
self.dropout = params.dropout
|
||||
self.attention_dropout = params.attention_dropout
|
||||
self.n_heads = config.n_heads # 8 by default
|
||||
self.n_layers = config.n_layers
|
||||
self.dropout = config.dropout
|
||||
self.attention_dropout = config.attention_dropout
|
||||
assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'
|
||||
|
||||
# embeddings
|
||||
self.position_embeddings = Embedding(params.max_position_embeddings, self.dim)
|
||||
if params.sinusoidal_embeddings:
|
||||
create_sinusoidal_embeddings(params.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
|
||||
if params.n_langs > 1:
|
||||
self.position_embeddings = Embedding(config.max_position_embeddings, self.dim)
|
||||
if config.sinusoidal_embeddings:
|
||||
create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
|
||||
if config.n_langs > 1:
|
||||
self.lang_embeddings = Embedding(self.n_langs, self.dim)
|
||||
self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
|
||||
self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)
|
||||
@ -535,26 +488,26 @@ class XLMModel(XLMPreTrainedModel):
|
||||
if self.is_decoder:
|
||||
self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
|
||||
self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
|
||||
self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=params.gelu_activation))
|
||||
self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=config.gelu_activation))
|
||||
self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
|
||||
|
||||
def forward(self, x, lengths, positions=None, langs=None, cache=None, head_mask=None): # src_enc=None, src_len=None,
|
||||
def forward(self, input_ids, lengths, positions=None, langs=None, cache=None, head_mask=None): # src_enc=None, src_len=None,
|
||||
"""
|
||||
Inputs:
|
||||
`x` LongTensor(bs, slen), containing word indices
|
||||
`input_ids` LongTensor(bs, slen), containing word indices
|
||||
`lengths` LongTensor(bs), containing the length of each sentence
|
||||
`causal` Boolean, if True, the attention is only done over previous hidden states
|
||||
`positions` LongTensor(bs, slen), containing word positions
|
||||
`langs` LongTensor(bs, slen), containing language IDs
|
||||
"""
|
||||
# lengths = (x != self.pad_index).float().sum(dim=1)
|
||||
# mask = x != self.pad_index
|
||||
# lengths = (input_ids != self.pad_index).float().sum(dim=1)
|
||||
# mask = input_ids != self.pad_index
|
||||
|
||||
# check inputs
|
||||
bs, slen = x.size()
|
||||
bs, slen = input_ids.size()
|
||||
assert lengths.size(0) == bs
|
||||
assert lengths.max().item() <= slen
|
||||
# x = x.transpose(0, 1) # batch size as dimension 0
|
||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||
# assert (src_enc is None) == (src_len is None)
|
||||
# if src_enc is not None:
|
||||
# assert self.is_decoder
|
||||
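A minimal sketch of assembling the two mandatory inputs documented above, `input_ids` of shape (bs, slen) and per-sentence `lengths`, from a padded batch; the pad index here is illustrative and would normally come from the configuration.

```python
import torch

pad_index = 2                                   # illustrative; normally config.pad_index
input_ids = torch.tensor([[31, 51, 99, pad_index],
                          [15,  5, pad_index, pad_index]])   # (bs=2, slen=4)
lengths = (input_ids != pad_index).sum(dim=1)   # tensor([3, 2])

# model = XLMModel(config)                      # hypothetical: needs a fully populated XLMConfig
# hidden_states = model(input_ids, lengths)
```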
@ -567,7 +520,7 @@ class XLMModel(XLMPreTrainedModel):
|
||||
|
||||
# positions
|
||||
if positions is None:
|
||||
positions = x.new(slen).long()
|
||||
positions = input_ids.new(slen).long()
|
||||
positions = torch.arange(slen, out=positions).unsqueeze(0)
|
||||
else:
|
||||
assert positions.size() == (bs, slen) # (slen, bs)
|
||||
@ -581,7 +534,7 @@ class XLMModel(XLMPreTrainedModel):
|
||||
# do not recompute cached elements
|
||||
if cache is not None:
|
||||
_slen = slen - cache['slen']
|
||||
x = x[:, -_slen:]
|
||||
input_ids = input_ids[:, -_slen:]
|
||||
positions = positions[:, -_slen:]
|
||||
if langs is not None:
|
||||
langs = langs[:, -_slen:]
|
||||
@ -589,7 +542,7 @@ class XLMModel(XLMPreTrainedModel):
|
||||
attn_mask = attn_mask[:, -_slen:]
|
||||
|
||||
# embeddings
|
||||
tensor = self.embeddings(x)
|
||||
tensor = self.embeddings(input_ids)
|
||||
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
|
||||
if langs is not None:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
@ -648,21 +601,21 @@ class XLMPredLayer(nn.Module):
|
||||
"""
|
||||
Prediction layer (cross_entropy or adaptive_softmax).
|
||||
"""
|
||||
def __init__(self, params):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.asm = params.asm
|
||||
self.n_words = params.n_words
|
||||
self.pad_index = params.pad_index
|
||||
dim = params.emb_dim
|
||||
self.asm = config.asm
|
||||
self.n_words = config.n_words
|
||||
self.pad_index = config.pad_index
|
||||
dim = config.emb_dim
|
||||
|
||||
if params.asm is False:
|
||||
self.proj = Linear(dim, params.n_words, bias=True)
|
||||
if config.asm is False:
|
||||
self.proj = Linear(dim, config.n_words, bias=True)
|
||||
else:
|
||||
self.proj = nn.AdaptiveLogSoftmaxWithLoss(
|
||||
in_features=dim,
|
||||
n_classes=params.n_words,
|
||||
cutoffs=params.asm_cutoffs,
|
||||
div_value=params.asm_div_value,
|
||||
n_classes=config.n_words,
|
||||
cutoffs=config.asm_cutoffs,
|
||||
div_value=config.asm_div_value,
|
||||
head_bias=True, # default is False
|
||||
)
|
||||
|
||||
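For context, a standalone sketch of PyTorch's nn.AdaptiveLogSoftmaxWithLoss, the module the `asm` branch above delegates to; the feature size, class count, and cutoffs are illustrative.

```python
import torch
import torch.nn as nn

dim, n_words = 64, 1000
asm_head = nn.AdaptiveLogSoftmaxWithLoss(in_features=dim,
                                         n_classes=n_words,
                                         cutoffs=[100, 500],   # frequency-ordered buckets
                                         div_value=4.0,
                                         head_bias=True)       # default is False, as noted above

hidden = torch.randn(8, dim)                    # 8 token representations
targets = torch.randint(0, n_words, (8,))
output, loss = asm_head(hidden, targets)        # per-target log-probs and mean NLL loss
```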
@ -691,66 +644,63 @@ class XLMPredLayer(nn.Module):
|
||||
|
||||
|
||||
class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
|
||||
Paper: https://arxiv.org/abs/1901.07291
|
||||
Original code: https://github.com/facebookresearch/XLM
|
||||
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
|
||||
Paper: https://arxiv.org/abs/1901.07291
|
||||
Original code: https://github.com/facebookresearch/XLM
|
||||
|
||||
Params:
|
||||
`config`: a XLMConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
Params:
|
||||
`config`: a XLMConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLM paper for more details).
|
||||
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
||||
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLM paper for more details).
|
||||
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
||||
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
|
||||
Outputs: Tuple of (encoded_layers, pooled_output)
|
||||
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
|
||||
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
|
||||
of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
|
||||
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
|
||||
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
|
||||
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
||||
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
|
||||
classifier pretrained on top of the hidden state associated to the first character of the
|
||||
input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
|
||||
Outputs: Tuple of (encoded_layers, pooled_output)
|
||||
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
|
||||
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
|
||||
of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
|
||||
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
|
||||
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
|
||||
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
||||
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
|
||||
classifier pretrained on top of the hidden state associated to the first character of the
|
||||
input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.XLMModel(config=config)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||
model = modeling.XLMModel(config=config)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLMWithLMHeadModel, self).__init__(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.output_hidden_states = output_hidden_states
|
||||
|
||||
self.attn_type = config.attn_type
|
||||
self.same_length = config.same_length
|
||||
|
||||
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
||||
self.transformer = XLMModel(config)
|
||||
self.pred_layer = XLMPredLayer(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
@ -761,7 +711,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
"""
|
||||
self.pred_layer.proj.weight = self.transformer.embeddings.weight
|
||||
|
||||
def forward(self, x, lengths, positions=None, langs=None, cache=None,
|
||||
def forward(self, input_ids, lengths, positions=None, langs=None, cache=None,
|
||||
labels=None, head_mask=None):
|
||||
"""
|
||||
Args:
|
||||
@ -789,7 +739,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
summary_type: str, "last", "first", "mean", or "attn". The method
|
||||
to pool the input to get a vector representation.
|
||||
"""
|
||||
transformer_outputs = self.transformer(x, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
|
||||
transformer_outputs = self.transformer(input_ids, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
logits = self.pred_layer(output, labels)
|
||||
@ -905,18 +855,12 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
||||
output_attentions=False, output_hidden_states=False):
|
||||
def __init__(self, config):
|
||||
super(XLMForSequenceClassification, self).__init__(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.output_hidden_states = output_hidden_states
|
||||
|
||||
self.summary_type = summary_type
|
||||
self.num_labels = num_labels
|
||||
self.transformer = XLMModel(config)
|
||||
|
||||
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
||||
|
||||
self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
|
||||
self.sequence_summary = XLMSequenceSummary(config)
|
||||
self.logits_proj = nn.Linear(config.d_model, num_labels)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
@ -1030,13 +974,12 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||
def __init__(self, config):
|
||||
super(XLMForQuestionAnswering, self).__init__(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.output_hidden_states = output_hidden_states
|
||||
|
||||
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
self.transformer = XLMModel(config)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
|
326
pytorch_pretrained_bert/tokenization_xlm.py
Normal file
@ -0,0 +1,326 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for OpenAI GPT."""
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .tokenization import BasicTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
|
||||
}
|
||||
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
|
||||
}
|
||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
'xlm-mlm-en-2048': 512,
|
||||
}
|
||||
VOCAB_NAME = 'vocab.json'
|
||||
MERGES_NAME = 'merges.txt'
|
||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
INDEX = {
|
||||
"bos_index": 0,
|
||||
"eos_index": 1,
|
||||
"pad_index": 2,
|
||||
"unk_index": 3,
|
||||
"mask_index": 5
|
||||
}
|
||||
|
||||
def get_pairs(word):
|
||||
"""
|
||||
Return set of symbol pairs in a word.
|
||||
word is represented as tuple of symbols (symbols being variable-length strings)
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
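A quick illustration of get_pairs on a word already split into BPE symbols (the symbols are illustrative):

```python
word = ('h', 'e', 'l', 'l', 'o</w>')
print(get_pairs(word))
# -> {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')}  (set order is arbitrary)
```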
|
||||
def text_standardize(text):
|
||||
"""
|
||||
fixes some issues the spacy tokenizer had on books corpus
|
||||
also does some whitespace standardization
|
||||
"""
|
||||
text = text.replace('—', '-')
|
||||
text = text.replace('–', '-')
|
||||
text = text.replace('―', '-')
|
||||
text = text.replace('…', '...')
|
||||
text = text.replace('´', "'")
|
||||
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
|
||||
text = re.sub(r'\s*\n\s*', ' \n ', text)
|
||||
text = re.sub(r'[^\S\n]+', ' ', text)
|
||||
return text.strip()
|
||||
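An illustrative call to text_standardize, showing the dash/ellipsis normalization and punctuation spacing (the sample sentence is made up):

```python
print(text_standardize("It was—well… strange,isn't it?"))
# -> "It was - well... strange , isn't it ?"
```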
|
||||
class XLMTokenizer(object):
|
||||
"""
|
||||
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
|
||||
- lower case all inputs
|
||||
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
|
||||
- argument special_tokens and function set_special_tokens:
|
||||
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
|
||||
"""
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
special_tokens_file = None
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
|
||||
if not os.path.exists(special_tokens_file):
|
||||
special_tokens_file = None
|
||||
else:
|
||||
logger.info("loading special tokens file {}".format(special_tokens_file))
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
logger.error(
|
||||
"Couldn't reach server at '{}' to download vocabulary.".format(
|
||||
vocab_file))
|
||||
else:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name_or_path,
|
||||
vocab_file, merges_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
logger.info("loading merges file {}".format(merges_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
logger.info("loading merges file {} from cache at {}".format(
|
||||
merges_file, resolved_merges_file))
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||
# than the number of positional embeddings
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
# Instantiate tokenizer.
|
||||
if special_tokens_file and 'special_tokens' not in kwargs:
|
||||
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
else:
|
||||
special_tokens = kwargs.pop('special_tokens', [])
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
|
||||
try:
|
||||
import ftfy
|
||||
import spacy
|
||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
||||
self.fix_text = ftfy.fix_text
|
||||
except ImportError:
|
||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||
self.nlp = BasicTokenizer(do_lower_case=True,
|
||||
never_split=special_tokens if special_tokens is not None else [])
|
||||
self.fix_text = None
|
||||
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
merges = [tuple(merge.split()[:2]) for merge in merges]
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||
if self.fix_text is None:
|
||||
# Using BERT's BasicTokenizer: we can update the tokenizer
|
||||
self.nlp.never_split = special_tokens
|
||||
logger.info("Special tokens {}".format(self.special_tokens))
|
||||
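A short sketch of how set_special_tokens extends an existing tokenizer's vocabulary; `tokenizer` stands for an already constructed XLMTokenizer, and the token strings are illustrative.

```python
# `tokenizer` is assumed to be an XLMTokenizer built from vocab/merges files.
base_size = len(tokenizer.encoder)
tokenizer.set_special_tokens(["__classify__", "__sep__"])

assert tokenizer.special_tokens["__classify__"] == base_size   # appended after the BPE vocab
assert tokenizer.special_tokens["__sep__"] == base_size + 1
assert len(tokenizer) == base_size + 2                         # __len__ counts both vocabularies
```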
|
||||
def bpe(self, token):
|
||||
word = tuple(token[:-1]) + (token[-1] + '</w>',)
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except ValueError:  # `first` no longer occurs in the remaining symbols
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
if word == '\n </w>':
|
||||
word = '\n</w>'
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
split_tokens = []
|
||||
if self.fix_text is None:
|
||||
# Using BERT's BasicTokenizer
|
||||
text = self.nlp.tokenize(text)
|
||||
for token in text:
|
||||
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
||||
else:
|
||||
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
|
||||
text = self.nlp(text_standardize(self.fix_text(text)))
|
||||
for token in text:
|
||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.encoder.get(tokens, 0)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
out_string = out_string.replace('<unk>', '')
|
||||
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
|
||||
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
|
||||
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
|
||||
return out_string
|
||||
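A minimal end-to-end sketch of the tokenizer API, assuming the pretrained 'xlm-mlm-en-2048' vocabulary and merges files listed above can be downloaded (or that local paths are supplied instead):

```python
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')

tokens = tokenizer.tokenize("Hello world!")          # lower-cased BPE pieces ending in '</w>'
ids = tokenizer.convert_tokens_to_ids(tokens)
text = tokenizer.decode(ids)                         # '</w>' markers become spaces again
print(tokens, ids, text)
```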
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
|
||||
|
||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write(u'#version: 0.2\n')
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(merge_file))
|
||||
index = token_index
|
||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||
index += 1
|
||||
|
||||
index = len(self.encoder)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file, special_tokens_file
|
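Finally, a sketch of persisting a tokenizer with save_vocabulary and reloading it from the written files; the output directory is hypothetical and must already exist, and `tokenizer` stands for an existing XLMTokenizer instance.

```python
out_dir = "./xlm_tokenizer"   # hypothetical, pre-created directory
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(out_dir)

# Reload straight from the saved files (or equivalently via XLMTokenizer.from_pretrained(out_dir)).
reloaded = XLMTokenizer(vocab_file, merges_file)
assert len(reloaded.encoder) == len(tokenizer.encoder)
```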