Merge pull request #1670 from huggingface/templates
Templates and explanation for adding a new model and example script
commit 7f84fc571a
@@ -62,6 +62,8 @@ Awesome! Please provide the following information:
If you are willing to contribute the model yourself, let us know so we can best
guide you.

We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder.

### Do you want a new feature (that is not a model)?

A world-class feature request addresses the following points:
@@ -81,6 +83,8 @@ A world-class feature request addresses the following points:
If your issue is well written, we're already 80% of the way there by the time you
post it.

We have added **templates** to guide you in the process of adding a new example script for training or testing the models in the library. You can find them in the [`templates`](./templates) folder.

## Start contributing! (Pull Requests)

Before writing code, we strongly advise you to search through the existing PRs or
@@ -122,6 +122,7 @@ At some point in the future, you'll be able to seamlessly move from pre-training
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
8. **[DistilBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation).
9. **[CTRL](https://github.com/salesforce/ctrl/)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
10. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.

These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
templates/adding_a_new_example_script/README.md (new file, 5 lines)
@@ -0,0 +1,5 @@
# How to add a new example script in 🤗Transformers

This folder provides a template for adding a new example script implementing a training or inference task with the models in the 🤗Transformers library.

Currently, only PyTorch examples are provided. They are adaptations of the library's SQuAD examples and implement single-GPU and distributed training with gradient accumulation and mixed precision (using NVIDIA's apex library) to cover a reasonable range of use cases.
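
As a rough sketch of how an adapted script is typically launched (the flags below mirror the argparse options defined in `run_xxx.py`; the model shortcut and the data/output paths are placeholders to replace for your own task):

python run_xxx.py \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file $DATA_DIR/train-v1.1.json \
    --predict_file $DATA_DIR/dev-v1.1.json \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir /tmp/debug_xxx/

Distributed and mixed-precision runs follow the usual pattern already wired into the script (launch with `python -m torch.distributed.launch --nproc_per_node=N` and add `--fp16`).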
templates/adding_a_new_example_script/run_xxx.py (new file, 553 lines)
@@ -0,0 +1,553 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 XXX. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for task XXX."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer,
|
||||
XLMConfig, XLMForQuestionAnswering,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
|
||||
from transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||
RawResult, write_predictions,
|
||||
RawResultExtended, write_predictions_extended)
|
||||
|
||||
# The following import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
# We've added it here for automated tests (see examples/test_examples.py file)
|
||||
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
}
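
# Illustrative sketch: to adapt this template to a new model, the usual pattern is to import
# its config/model/tokenizer classes and register them under a new key in MODEL_CLASSES.
# `XxxConfig`, `XxxForQuestionAnswering` and `XxxTokenizer` below are placeholder names, not
# real classes in the library:
#
#   MODEL_CLASSES['xxx'] = (XxxConfig, XxxForQuestionAnswering, XxxTokenizer)
#
# The --model_type argument parsed in main() is then used to look up this triple.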
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
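    # Illustrative numbers (assumed, not from the library): with 88,000 training features and a
    # per-GPU batch size of 8 on one GPU, there are 11,000 batches per epoch; with
    # gradient_accumulation_steps=4 that is 2,750 optimizer updates per epoch, so
    # num_train_epochs=3 gives t_total = 8,250 scheduler steps.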
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5],
|
||||
'p_mask': batch[6]})
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
||||
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
all_results = []
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1]
|
||||
}
|
||||
if args.model_type != 'distilbert':
|
||||
                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM doesn't use segment_ids
|
||||
example_indices = batch[3]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[4],
|
||||
'p_mask': batch[5]})
|
||||
outputs = model(**inputs)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
result = RawResultExtended(unique_id = unique_id,
|
||||
start_top_log_probs = to_list(outputs[0][i]),
|
||||
start_top_index = to_list(outputs[1][i]),
|
||||
end_top_log_probs = to_list(outputs[2][i]),
|
||||
end_top_index = to_list(outputs[3][i]),
|
||||
cls_logits = to_list(outputs[4][i]))
|
||||
else:
|
||||
result = RawResult(unique_id = unique_id,
|
||||
start_logits = to_list(outputs[0][i]),
|
||||
end_logits = to_list(outputs[1][i]))
|
||||
all_results.append(result)
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
if args.version_2_with_negative:
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||
model.config.start_n_top, model.config.end_n_top,
|
||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||
else:
|
||||
write_predictions(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
||||
pred_file=output_prediction_file,
|
||||
na_prob_file=output_null_log_odds_file)
|
||||
results = evaluate_on_squad(evaluate_options)
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length)))
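    # Illustrative example: with --model_name_or_path bert-base-uncased and --max_seq_length 384,
    # the training cache is written as 'cached_train_bert-base-uncased_384' next to the input file.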
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
examples = read_squad_examples(input_file=input_file,
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
features = convert_examples_to_features(examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
if evaluate:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index, all_cls_index, all_p_mask)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions,
|
||||
all_cls_index, all_p_mask)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
|
||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||
|
||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
||||
parser.add_argument("--doc_stride", default=128, type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.")
|
||||
parser.add_argument("--max_query_length", default=64, type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--n_best_size", default=20, type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.")
|
||||
parser.add_argument("--verbose_logging", action='store_true',
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
|
||||
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
|
||||
# remove the need for this code, but it is still valid.
|
||||
if args.fp16:
|
||||
try:
|
||||
import apex
|
||||
apex.amp.register_half_function(torch, 'einsum')
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Save the trained model and the tokenizer
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
templates/adding_a_new_example_script/utils_xxx.py (new file, 995 lines)
@@ -0,0 +1,995 @@
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2018 XXX. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Load XXX dataset. """
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import collections
|
||||
from io import open
|
||||
|
||||
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||
|
||||
# Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
|
||||
from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SquadExample(object):
|
||||
"""
|
||||
A single training/test example for the Squad dataset.
|
||||
For examples without an answer, the start and end position are -1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
qas_id,
|
||||
question_text,
|
||||
doc_tokens,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
self.qas_id = qas_id
|
||||
self.question_text = question_text
|
||||
self.doc_tokens = doc_tokens
|
||||
self.orig_answer_text = orig_answer_text
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
s = ""
|
||||
s += "qas_id: %s" % (self.qas_id)
|
||||
s += ", question_text: %s" % (
|
||||
self.question_text)
|
||||
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
|
||||
if self.start_position:
|
||||
s += ", start_position: %d" % (self.start_position)
|
||||
if self.end_position:
|
||||
s += ", end_position: %d" % (self.end_position)
|
||||
if self.is_impossible:
|
||||
s += ", is_impossible: %r" % (self.is_impossible)
|
||||
return s
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tokens,
|
||||
token_to_orig_map,
|
||||
token_is_max_context,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
cls_index,
|
||||
p_mask,
|
||||
paragraph_len,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
self.unique_id = unique_id
|
||||
self.example_index = example_index
|
||||
self.doc_span_index = doc_span_index
|
||||
self.tokens = tokens
|
||||
self.token_to_orig_map = token_to_orig_map
|
||||
self.token_is_max_context = token_is_max_context
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.cls_index = cls_index
|
||||
self.p_mask = p_mask
|
||||
self.paragraph_len = paragraph_len
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
|
||||
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
"""Read a SQuAD json file into a list of SquadExample."""
|
||||
with open(input_file, "r", encoding='utf-8') as reader:
|
||||
input_data = json.load(reader)["data"]
|
||||
|
||||
def is_whitespace(c):
|
||||
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
||||
return True
|
||||
return False
|
||||
|
||||
examples = []
|
||||
for entry in input_data:
|
||||
for paragraph in entry["paragraphs"]:
|
||||
paragraph_text = paragraph["context"]
|
||||
doc_tokens = []
|
||||
char_to_word_offset = []
|
||||
prev_is_whitespace = True
|
||||
for c in paragraph_text:
|
||||
if is_whitespace(c):
|
||||
prev_is_whitespace = True
|
||||
else:
|
||||
if prev_is_whitespace:
|
||||
doc_tokens.append(c)
|
||||
else:
|
||||
doc_tokens[-1] += c
|
||||
prev_is_whitespace = False
|
||||
char_to_word_offset.append(len(doc_tokens) - 1)
|
||||
|
||||
for qa in paragraph["qas"]:
|
||||
qas_id = qa["id"]
|
||||
question_text = qa["question"]
|
||||
start_position = None
|
||||
end_position = None
|
||||
orig_answer_text = None
|
||||
is_impossible = False
|
||||
if is_training:
|
||||
if version_2_with_negative:
|
||||
is_impossible = qa["is_impossible"]
|
||||
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||
raise ValueError(
|
||||
"For training, each question should have exactly 1 answer.")
|
||||
if not is_impossible:
|
||||
answer = qa["answers"][0]
|
||||
orig_answer_text = answer["text"]
|
||||
answer_offset = answer["answer_start"]
|
||||
answer_length = len(orig_answer_text)
|
||||
start_position = char_to_word_offset[answer_offset]
|
||||
end_position = char_to_word_offset[answer_offset + answer_length - 1]
|
||||
# Only add answers where the text can be exactly recovered from the
|
||||
# document. If this CAN'T happen it's likely due to weird Unicode
|
||||
# stuff so we will just skip the example.
|
||||
#
|
||||
# Note that this means for training mode, every example is NOT
|
||||
# guaranteed to be preserved.
|
||||
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
|
||||
cleaned_answer_text = " ".join(
|
||||
whitespace_tokenize(orig_answer_text))
|
||||
if actual_text.find(cleaned_answer_text) == -1:
|
||||
logger.warning("Could not find answer: '%s' vs. '%s'",
|
||||
actual_text, cleaned_answer_text)
|
||||
continue
|
||||
else:
|
||||
start_position = -1
|
||||
end_position = -1
|
||||
orig_answer_text = ""
|
||||
|
||||
example = SquadExample(
|
||||
qas_id=qas_id,
|
||||
question_text=question_text,
|
||||
doc_tokens=doc_tokens,
|
||||
orig_answer_text=orig_answer_text,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=is_impossible)
|
||||
examples.append(example)
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
doc_stride, max_query_length, is_training,
|
||||
cls_token_at_end=False,
|
||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
unique_id = 1000000000
|
||||
# cnt_pos, cnt_neg = 0, 0
|
||||
# max_N, max_M = 1024, 1024
|
||||
# f = np.zeros((max_N, max_M), dtype=np.float32)
|
||||
|
||||
features = []
|
||||
for (example_index, example) in enumerate(examples):
|
||||
|
||||
# if example_index % 100 == 0:
|
||||
# logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
|
||||
|
||||
query_tokens = tokenizer.tokenize(example.question_text)
|
||||
|
||||
if len(query_tokens) > max_query_length:
|
||||
query_tokens = query_tokens[0:max_query_length]
|
||||
|
||||
tok_to_orig_index = []
|
||||
orig_to_tok_index = []
|
||||
all_doc_tokens = []
|
||||
for (i, token) in enumerate(example.doc_tokens):
|
||||
orig_to_tok_index.append(len(all_doc_tokens))
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
for sub_token in sub_tokens:
|
||||
tok_to_orig_index.append(i)
|
||||
all_doc_tokens.append(sub_token)
|
||||
|
||||
tok_start_position = None
|
||||
tok_end_position = None
|
||||
if is_training and example.is_impossible:
|
||||
tok_start_position = -1
|
||||
tok_end_position = -1
|
||||
if is_training and not example.is_impossible:
|
||||
tok_start_position = orig_to_tok_index[example.start_position]
|
||||
if example.end_position < len(example.doc_tokens) - 1:
|
||||
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
||||
else:
|
||||
tok_end_position = len(all_doc_tokens) - 1
|
||||
(tok_start_position, tok_end_position) = _improve_answer_span(
|
||||
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
|
||||
example.orig_answer_text)
|
||||
|
||||
# The -3 accounts for [CLS], [SEP] and [SEP]
|
||||
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
||||
|
||||
# We can have documents that are longer than the maximum sequence length.
|
||||
# To deal with this we do a sliding window approach, where we take chunks
|
||||
        # of up to our max length with a stride of `doc_stride`.
|
||||
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"DocSpan", ["start", "length"])
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
length = len(all_doc_tokens) - start_offset
|
||||
if length > max_tokens_for_doc:
|
||||
length = max_tokens_for_doc
|
||||
doc_spans.append(_DocSpan(start=start_offset, length=length))
|
||||
if start_offset + length == len(all_doc_tokens):
|
||||
break
|
||||
start_offset += min(length, doc_stride)
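        # Illustrative numbers: with 500 document tokens, max_tokens_for_doc=381 and doc_stride=128,
        # the loop above produces two spans, (start=0, length=381) and (start=128, length=372),
        # which together cover the whole document.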
|
||||
|
||||
for (doc_span_index, doc_span) in enumerate(doc_spans):
|
||||
tokens = []
|
||||
token_to_orig_map = {}
|
||||
token_is_max_context = {}
|
||||
segment_ids = []
|
||||
|
||||
            # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
|
||||
            # Original TF implem also keeps the classification token (set to 0) (not sure why...)
|
||||
p_mask = []
|
||||
|
||||
# CLS token at the beginning
|
||||
if not cls_token_at_end:
|
||||
tokens.append(cls_token)
|
||||
segment_ids.append(cls_token_segment_id)
|
||||
p_mask.append(0)
|
||||
cls_index = 0
|
||||
|
||||
# Query
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# Paragraph
|
||||
for i in range(doc_span.length):
|
||||
split_token_index = doc_span.start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
|
||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
||||
split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
p_mask.append(0)
|
||||
paragraph_len = doc_span.length
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# CLS token at the end
|
||||
if cls_token_at_end:
|
||||
tokens.append(cls_token)
|
||||
segment_ids.append(cls_token_segment_id)
|
||||
p_mask.append(0)
|
||||
cls_index = len(tokens) - 1 # Index of classification token
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(pad_token)
|
||||
input_mask.append(0 if mask_padding_with_zero else 1)
|
||||
segment_ids.append(pad_token_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
span_is_impossible = example.is_impossible
|
||||
start_position = None
|
||||
end_position = None
|
||||
if is_training and not span_is_impossible:
|
||||
# For training, if our document chunk does not contain an annotation
|
||||
# we throw it out, since there is nothing to predict.
|
||||
doc_start = doc_span.start
|
||||
doc_end = doc_span.start + doc_span.length - 1
|
||||
out_of_span = False
|
||||
if not (tok_start_position >= doc_start and
|
||||
tok_end_position <= doc_end):
|
||||
out_of_span = True
|
||||
if out_of_span:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
span_is_impossible = True
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
end_position = tok_end_position - doc_start + doc_offset
|
||||
|
||||
if is_training and span_is_impossible:
|
||||
start_position = cls_index
|
||||
end_position = cls_index
|
||||
|
||||
if example_index < 20:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("unique_id: %s" % (unique_id))
|
||||
logger.info("example_index: %s" % (example_index))
|
||||
logger.info("doc_span_index: %s" % (doc_span_index))
|
||||
logger.info("tokens: %s" % " ".join(tokens))
|
||||
logger.info("token_to_orig_map: %s" % " ".join([
|
||||
"%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
|
||||
logger.info("token_is_max_context: %s" % " ".join([
|
||||
"%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
|
||||
]))
|
||||
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||
logger.info(
|
||||
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
logger.info(
|
||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
if is_training and span_is_impossible:
|
||||
logger.info("impossible example")
|
||||
if is_training and not span_is_impossible:
|
||||
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
||||
logger.info("start_position: %d" % (start_position))
|
||||
logger.info("end_position: %d" % (end_position))
|
||||
logger.info(
|
||||
"answer: %s" % (answer_text))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
unique_id=unique_id,
|
||||
example_index=example_index,
|
||||
doc_span_index=doc_span_index,
|
||||
tokens=tokens,
|
||||
token_to_orig_map=token_to_orig_map,
|
||||
token_is_max_context=token_is_max_context,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
cls_index=cls_index,
|
||||
p_mask=p_mask,
|
||||
paragraph_len=paragraph_len,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=span_is_impossible))
|
||||
unique_id += 1
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
|
||||
orig_answer_text):
|
||||
"""Returns tokenized answer spans that better match the annotated answer."""
|
||||
|
||||
# The SQuAD annotations are character based. We first project them to
|
||||
# whitespace-tokenized words. But then after WordPiece tokenization, we can
|
||||
# often find a "better match". For example:
|
||||
#
|
||||
# Question: What year was John Smith born?
|
||||
# Context: The leader was John Smith (1895-1943).
|
||||
# Answer: 1895
|
||||
#
|
||||
# The original whitespace-tokenized answer will be "(1895-1943).". However
|
||||
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
|
||||
# the exact answer, 1895.
|
||||
#
|
||||
# However, this is not always possible. Consider the following:
|
||||
#
|
||||
    # Question: What country is the top exporter of electronics?
|
||||
    # Context: The Japanese electronics industry is the largest in the world.
|
||||
# Answer: Japan
|
||||
#
|
||||
# In this case, the annotator chose "Japan" as a character sub-span of
|
||||
# the word "Japanese". Since our WordPiece tokenizer does not split
|
||||
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
|
||||
# in SQuAD, but does happen.
|
||||
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
|
||||
|
||||
for new_start in range(input_start, input_end + 1):
|
||||
for new_end in range(input_end, new_start - 1, -1):
|
||||
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
|
||||
if text_span == tok_answer_text:
|
||||
return (new_start, new_end)
|
||||
|
||||
return (input_start, input_end)
|
||||
|
||||
|
||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
"""Check if this is the 'max context' doc span for the token."""
|
||||
|
||||
# Because of the sliding window approach taken to scoring documents, a single
|
||||
# token can appear in multiple documents. E.g.
|
||||
# Doc: the man went to the store and bought a gallon of milk
|
||||
# Span A: the man went to the
|
||||
# Span B: to the store and bought
|
||||
# Span C: and bought a gallon of
|
||||
# ...
|
||||
#
|
||||
# Now the word 'bought' will have two scores from spans B and C. We only
|
||||
# want to consider the score with "maximum context", which we define as
|
||||
# the *minimum* of its left and right context (the *sum* of left and
|
||||
# right context will always be the same, of course).
|
||||
#
|
||||
# In the example the maximum context for 'bought' would be span C since
|
||||
# it has 1 left context and 3 right context, while span B has 4 left context
|
||||
# and 0 right context.
|
||||
best_score = None
|
||||
best_span_index = None
|
||||
for (span_index, doc_span) in enumerate(doc_spans):
|
||||
end = doc_span.start + doc_span.length - 1
|
||||
if position < doc_span.start:
|
||||
continue
|
||||
if position > end:
|
||||
continue
|
||||
num_left_context = position - doc_span.start
|
||||
num_right_context = end - position
|
||||
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
||||
if best_score is None or score > best_score:
|
||||
best_score = score
|
||||
best_span_index = span_index
|
||||
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
|
||||
RawResult = collections.namedtuple("RawResult",
|
||||
["unique_id", "start_logits", "end_logits"])
|
||||
|
||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
||||
version_2_with_negative, null_score_diff_threshold):
|
||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
||||
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||
|
||||
example_index_to_features = collections.defaultdict(list)
|
||||
for feature in all_features:
|
||||
example_index_to_features[feature.example_index].append(feature)
|
||||
|
||||
unique_id_to_result = {}
|
||||
for result in all_results:
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
||||
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
for (example_index, example) in enumerate(all_examples):
|
||||
features = example_index_to_features[example_index]
|
||||
|
||||
prelim_predictions = []
|
||||
# keep track of the minimum score of null start+end of position 0
|
||||
score_null = 1000000 # large and positive
|
||||
min_null_feature_index = 0 # the paragraph slice with min null score
|
||||
null_start_logit = 0 # the start logit at the slice with min null score
|
||||
null_end_logit = 0 # the end logit at the slice with min null score
|
||||
for (feature_index, feature) in enumerate(features):
|
||||
result = unique_id_to_result[feature.unique_id]
|
||||
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
||||
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
||||
# if we could have irrelevant answers, get the min score of irrelevant
|
||||
if version_2_with_negative:
|
||||
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
||||
if feature_null_score < score_null:
|
||||
score_null = feature_null_score
|
||||
min_null_feature_index = feature_index
|
||||
null_start_logit = result.start_logits[0]
|
||||
null_end_logit = result.end_logits[0]
|
||||
for start_index in start_indexes:
|
||||
for end_index in end_indexes:
|
||||
# We could hypothetically create invalid predictions, e.g., predict
|
||||
# that the start of the span is in the question. We throw out all
|
||||
# invalid predictions.
|
||||
if start_index >= len(feature.tokens):
|
||||
continue
|
||||
if end_index >= len(feature.tokens):
|
||||
continue
|
||||
if start_index not in feature.token_to_orig_map:
|
||||
continue
|
||||
if end_index not in feature.token_to_orig_map:
|
||||
continue
|
||||
if not feature.token_is_max_context.get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index:
|
||||
continue
|
||||
length = end_index - start_index + 1
|
||||
if length > max_answer_length:
|
||||
continue
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=feature_index,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_logit=result.start_logits[start_index],
|
||||
end_logit=result.end_logits[end_index]))
|
||||
if version_2_with_negative:
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=min_null_feature_index,
|
||||
start_index=0,
|
||||
end_index=0,
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_logit + x.end_logit),
|
||||
reverse=True)
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
for pred in prelim_predictions:
|
||||
if len(nbest) >= n_best_size:
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
if pred.start_index > 0: # this is a non-null prediction
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
tok_text = " ".join(tok_tokens)
|
||||
|
||||
# De-tokenize WordPieces that have been split off.
|
||||
tok_text = tok_text.replace(" ##", "")
|
||||
tok_text = tok_text.replace("##", "")
|
||||
|
||||
# Clean whitespace
|
||||
tok_text = tok_text.strip()
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
|
||||
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
|
||||
seen_predictions[final_text] = True
|
||||
else:
|
||||
final_text = ""
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_logit=pred.start_logit,
|
||||
end_logit=pred.end_logit))
|
||||
# if we didn't include the empty option in the n-best, include it
|
||||
if version_2_with_negative:
|
||||
if "" not in seen_predictions:
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text="",
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
|
||||
            # In very rare edge cases we could have only a single null prediction.
|
||||
# So we just create a nonce prediction in this case to avoid failure.
|
||||
if len(nbest)==1:
|
||||
nbest.insert(0,
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
assert len(nbest) >= 1
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
for entry in nbest:
|
||||
total_scores.append(entry.start_logit + entry.end_logit)
|
||||
if not best_non_null_entry:
|
||||
if entry.text:
|
||||
best_non_null_entry = entry
|
||||
|
||||
probs = _compute_softmax(total_scores)
|
||||
|
||||
nbest_json = []
|
||||
for (i, entry) in enumerate(nbest):
|
||||
output = collections.OrderedDict()
|
||||
output["text"] = entry.text
|
||||
output["probability"] = probs[i]
|
||||
output["start_logit"] = entry.start_logit
|
||||
output["end_logit"] = entry.end_logit
|
||||
nbest_json.append(output)
|
||||
|
||||
assert len(nbest_json) >= 1
|
||||
|
||||
if not version_2_with_negative:
|
||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||
else:
|
||||
# predict "" iff the null score - the score of best non-null > threshold
|
||||
score_diff = score_null - best_non_null_entry.start_logit - (
|
||||
best_non_null_entry.end_logit)
|
||||
scores_diff_json[example.qas_id] = score_diff
|
||||
if score_diff > null_score_diff_threshold:
|
||||
all_predictions[example.qas_id] = ""
|
||||
else:
|
||||
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||
all_nbest_json[example.qas_id] = nbest_json
|
||||
|
||||
with open(output_prediction_file, "w") as writer:
|
||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||
|
||||
with open(output_nbest_file, "w") as writer:
|
||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||
|
||||
if version_2_with_negative:
|
||||
with open(output_null_log_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
return all_predictions
|
||||
|
||||
|
||||
# For XLNet (and XLM which uses the same head)
|
||||
RawResultExtended = collections.namedtuple("RawResultExtended",
|
||||
["unique_id", "start_top_log_probs", "start_top_index",
|
||||
"end_top_log_probs", "end_top_index", "cls_logits"])
|
||||
|
||||
|
||||
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file, orig_data_file,
|
||||
start_n_top, end_n_top, version_2_with_negative,
|
||||
tokenizer, verbose_logging):
|
||||
""" XLNet write prediction logic (more complex than Bert's).
|
||||
Write final predictions to the json file and log-odds of null if needed.
|
||||
|
||||
Requires utils_squad_evaluate.py
|
||||
"""
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index",
|
||||
"start_log_prob", "end_log_prob"])
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
||||
|
||||
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||
|
||||
example_index_to_features = collections.defaultdict(list)
|
||||
for feature in all_features:
|
||||
example_index_to_features[feature.example_index].append(feature)
|
||||
|
||||
unique_id_to_result = {}
|
||||
for result in all_results:
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
for (example_index, example) in enumerate(all_examples):
|
||||
features = example_index_to_features[example_index]
|
||||
|
||||
prelim_predictions = []
|
||||
# keep track of the minimum score of null start+end of position 0
|
||||
score_null = 1000000 # large and positive
|
||||
|
||||
for (feature_index, feature) in enumerate(features):
|
||||
result = unique_id_to_result[feature.unique_id]
|
||||
|
||||
cur_null_score = result.cls_logits
|
||||
|
||||
# if we could have irrelevant answers, get the min score of irrelevant
|
||||
score_null = min(score_null, cur_null_score)
|
||||
|
||||
for i in range(start_n_top):
|
||||
for j in range(end_n_top):
|
||||
start_log_prob = result.start_top_log_probs[i]
|
||||
start_index = result.start_top_index[i]
|
||||
|
||||
j_index = i * end_n_top + j
|
||||
|
||||
end_log_prob = result.end_top_log_probs[j_index]
|
||||
end_index = result.end_top_index[j_index]
|
||||
|
||||
# We could hypothetically create invalid predictions, e.g., predict
|
||||
# that the start of the span is in the question. We throw out all
|
||||
# invalid predictions.
|
||||
if start_index >= feature.paragraph_len - 1:
|
||||
continue
|
||||
if end_index >= feature.paragraph_len - 1:
|
||||
continue
|
||||
|
||||
if not feature.token_is_max_context.get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index:
|
||||
continue
|
||||
length = end_index - start_index + 1
|
||||
if length > max_answer_length:
|
||||
continue
|
||||
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=feature_index,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_log_prob=start_log_prob,
|
||||
end_log_prob=end_log_prob))
|
||||
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
||||
reverse=True)
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
for pred in prelim_predictions:
|
||||
if len(nbest) >= n_best_size:
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
|
||||
# XLNet un-tokenizer
|
||||
# Let's keep it simple for now and see if we need all this later.
|
||||
#
|
||||
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||
# end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||
# paragraph_text = example.paragraph_text
|
||||
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||
|
||||
# Previously used Bert untokenizer
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||
|
||||
# Clean whitespace
|
||||
tok_text = tok_text.strip()
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
|
||||
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
||||
verbose_logging)
|
||||
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_log_prob=pred.start_log_prob,
|
||||
end_log_prob=pred.end_log_prob))
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="", start_log_prob=-1e6,
|
||||
end_log_prob=-1e6))
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
for entry in nbest:
|
||||
total_scores.append(entry.start_log_prob + entry.end_log_prob)
|
||||
if not best_non_null_entry:
|
||||
best_non_null_entry = entry
|
||||
|
||||
probs = _compute_softmax(total_scores)
|
||||
|
||||
nbest_json = []
|
||||
for (i, entry) in enumerate(nbest):
|
||||
output = collections.OrderedDict()
|
||||
output["text"] = entry.text
|
||||
output["probability"] = probs[i]
|
||||
output["start_log_prob"] = entry.start_log_prob
|
||||
output["end_log_prob"] = entry.end_log_prob
|
||||
nbest_json.append(output)
|
||||
|
||||
assert len(nbest_json) >= 1
|
||||
assert best_non_null_entry is not None
|
||||
|
||||
score_diff = score_null
|
||||
scores_diff_json[example.qas_id] = score_diff
|
||||
# note(zhiliny): always predict best_non_null_entry
|
||||
# and the evaluation script will search for the best threshold
|
||||
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||
|
||||
all_nbest_json[example.qas_id] = nbest_json
|
||||
|
||||
with open(output_prediction_file, "w") as writer:
|
||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||
|
||||
with open(output_nbest_file, "w") as writer:
|
||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||
|
||||
if version_2_with_negative:
|
||||
with open(output_null_log_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
with open(orig_data_file, "r", encoding='utf-8') as reader:
|
||||
orig_data = json.load(reader)["data"]
|
||||
|
||||
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||
exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
|
||||
out_eval = {}
|
||||
|
||||
find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans)
|
||||
|
||||
return out_eval
|
||||
|
||||
|
||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
"""Project the tokenized prediction back to the original text."""
|
||||
|
||||
# When we created the data, we kept track of the alignment between original
|
||||
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
||||
# now `orig_text` contains the span of our original text corresponding to the
|
||||
# span that we predicted.
|
||||
#
|
||||
# However, `orig_text` may contain extra characters that we don't want in
|
||||
# our prediction.
|
||||
#
|
||||
# For example, let's say:
|
||||
# pred_text = steve smith
|
||||
# orig_text = Steve Smith's
|
||||
#
|
||||
# We don't want to return `orig_text` because it contains the extra "'s".
|
||||
#
|
||||
# We don't want to return `pred_text` because it's already been normalized
|
||||
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
||||
# our tokenizer does additional normalization like stripping accent
|
||||
# characters).
|
||||
#
|
||||
# What we really want to return is "Steve Smith".
|
||||
#
|
||||
# Therefore, we have to apply a semi-complicated alignment heuristic between
|
||||
# `pred_text` and `orig_text` to get a character-to-character alignment. This
|
||||
# can fail in certain cases in which case we just return `orig_text`.
|
||||
|
||||
def _strip_spaces(text):
|
||||
ns_chars = []
|
||||
ns_to_s_map = collections.OrderedDict()
|
||||
for (i, c) in enumerate(text):
|
||||
if c == " ":
|
||||
continue
|
||||
ns_to_s_map[len(ns_chars)] = i
|
||||
ns_chars.append(c)
|
||||
ns_text = "".join(ns_chars)
|
||||
return (ns_text, ns_to_s_map)
|
||||
|
||||
# We first tokenize `orig_text`, strip whitespace from the result
|
||||
# and `pred_text`, and check if they are the same length. If they are
|
||||
# NOT the same length, the heuristic has failed. If they are the same
|
||||
# length, we assume the characters are one-to-one aligned.
|
||||
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
|
||||
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
||||
|
||||
start_position = tok_text.find(pred_text)
|
||||
if start_position == -1:
|
||||
if verbose_logging:
|
||||
logger.info(
|
||||
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||
return orig_text
|
||||
end_position = start_position + len(pred_text) - 1
|
||||
|
||||
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
||||
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
||||
|
||||
if len(orig_ns_text) != len(tok_ns_text):
|
||||
if verbose_logging:
|
||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
||||
orig_ns_text, tok_ns_text)
|
||||
return orig_text
|
||||
|
||||
# We then project the characters in `pred_text` back to `orig_text` using
|
||||
# the character-to-character alignment.
|
||||
tok_s_to_ns_map = {}
|
||||
for (i, tok_index) in tok_ns_to_s_map.items():
|
||||
tok_s_to_ns_map[tok_index] = i
|
||||
|
||||
orig_start_position = None
|
||||
if start_position in tok_s_to_ns_map:
|
||||
ns_start_position = tok_s_to_ns_map[start_position]
|
||||
if ns_start_position in orig_ns_to_s_map:
|
||||
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
||||
|
||||
if orig_start_position is None:
|
||||
if verbose_logging:
|
||||
logger.info("Couldn't map start position")
|
||||
return orig_text
|
||||
|
||||
orig_end_position = None
|
||||
if end_position in tok_s_to_ns_map:
|
||||
ns_end_position = tok_s_to_ns_map[end_position]
|
||||
if ns_end_position in orig_ns_to_s_map:
|
||||
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
||||
|
||||
if orig_end_position is None:
|
||||
if verbose_logging:
|
||||
logger.info("Couldn't map end position")
|
||||
return orig_text
|
||||
|
||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||
return output_text
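# Worked example of the heuristic above (illustrative): with pred_text = "steve smith",
# orig_text = "Steve Smith's" and do_lower_case=True, the lower-cased tokenization of
# orig_text is "steve smith ' s", pred_text is found at characters 0..10, and projecting
# these positions back through the space-stripped alignment returns "Steve Smith",
# dropping the possessive "'s" as intended.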
|
||||
|
||||
|
||||
def _get_best_indexes(logits, n_best_size):
|
||||
"""Get the n-best logits from a list."""
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
|
||||
best_indexes = []
|
||||
for i in range(len(index_and_score)):
|
||||
if i >= n_best_size:
|
||||
break
|
||||
best_indexes.append(index_and_score[i][0])
|
||||
return best_indexes
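# Example (illustrative): _get_best_indexes([0.1, 2.3, -1.0, 0.7], n_best_size=2)
# returns [1, 3], i.e. the positions of the two largest logits in decreasing score order.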
|
||||
|
||||
|
||||
def _compute_softmax(scores):
|
||||
"""Compute softmax probability over raw logits."""
|
||||
if not scores:
|
||||
return []
|
||||
|
||||
max_score = None
|
||||
for score in scores:
|
||||
if max_score is None or score > max_score:
|
||||
max_score = score
|
||||
|
||||
exp_scores = []
|
||||
total_sum = 0.0
|
||||
for score in scores:
|
||||
x = math.exp(score - max_score)
|
||||
exp_scores.append(x)
|
||||
total_sum += x
|
||||
|
||||
probs = []
|
||||
for score in exp_scores:
|
||||
probs.append(score / total_sum)
|
||||
return probs
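# Worked example (illustrative): _compute_softmax([1.0, 2.0, 3.0]) returns approximately
# [0.090, 0.245, 0.665]. Subtracting the maximum score before exponentiating is only for
# numerical stability; the result equals a plain softmax and sums to 1.0.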
|
62
templates/adding_a_new_model/README.md
Normal file
@ -0,0 +1,62 @@
|
||||
# How to add a new model in 🤗Transformers

This folder describes the process to add a new model in 🤗Transformers and provides templates for the required files.

The library is designed to incorporate a variety of models and code bases. As such, the process for adding a new model usually mostly consists of copy-pasting the relevant original code into the various sections of the templates included in the present repository.

One important point though is that the library has the following goals impacting the way models are incorporated:

- one specific feature of the API is the capability to run the model and tokenizer inline. The tokenization code thus often has to be slightly adapted to allow for running in the python interpreter.
- the package is also designed to be as self-consistent as possible, with a small and reliable set of package dependencies. As a consequence, additional dependencies are usually not allowed when adding a model but can be allowed for the inclusion of a new tokenizer (recent examples of dependencies added for tokenizer specificities include `sentencepiece` and `sacremoses`). Please make sure to check the existing dependencies when possible before adding a new one.

For a quick overview of the library organization, please check the [QuickStart section of the documentation](https://huggingface.co/transformers/quickstart.html).

# Typical workflow for including a model

Here is an overview of the general workflow:

- [ ] add model/configuration/tokenization classes
- [ ] add conversion scripts
- [ ] add tests
- [ ] finalize

Let's detail what should be done at each step.
## Adding model/configuration/tokenization classes

Here is the workflow for adding model/configuration/tokenization classes (a small sketch of the copy/rename/replace step follows this list):

- [ ] copy the python files from the present folder to the main folder and rename them, replacing `xxx` with your model name,
- [ ] edit the files to replace `XXX` (with various casing) with your model name,
- [ ] copy-paste or create a simple configuration class for your model in the `configuration_...` file,
- [ ] copy-paste or create the code for your model in the `modeling_...` files (PyTorch and TF 2.0),
- [ ] copy-paste or create a tokenizer class for your model in the `tokenization_...` file.
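As an illustration, the copy/rename/replace step could be scripted along the following lines. This is only a sketch: the target model name (`NewModel`/`newmodel`), the destination folder `transformers/` and the glob pattern are placeholders to adapt to your own model.

```python
from pathlib import Path

# Placeholder casings for a hypothetical model called "NewModel".
replacements = {"XXX": "NEWMODEL", "Xxx": "NewModel", "xxx": "newmodel"}

for template in Path("templates/adding_a_new_model").glob("*_xxx.py"):
    text = template.read_text(encoding="utf-8")
    for old, new in replacements.items():
        text = text.replace(old, new)
    target = Path("transformers") / template.name.replace("xxx", "newmodel")
    target.write_text(text, encoding="utf-8")
    print("wrote", target)
```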
## Adding conversion scripts

Here is the workflow for the conversion scripts:

- [ ] copy the conversion script (`convert_...`) from the present folder to the main folder,
- [ ] edit this script to convert your original checkpoint weights to the current PyTorch ones.
## Adding tests

Here is the workflow for adding tests:

- [ ] copy the python files from the `tests` sub-folder of the present folder to the `tests` sub-folder of the main folder and rename them, replacing `xxx` with your model name,
- [ ] edit the test files to replace `XXX` (with various casing) with your model name,
- [ ] edit the test code as needed.
## Final steps

You can then finish the addition by adding imports for your classes in the common files (a sketch of the `__init__.py` imports is shown after this list):

- [ ] add imports for all the relevant classes in `__init__.py`,
- [ ] add your configuration in `configuration_auto.py`,
- [ ] add your PyTorch and TF 2.0 models respectively in `modeling_auto.py` and `modeling_tf_auto.py`,
- [ ] add your tokenizer in `tokenization_auto.py`,
- [ ] add your models and tokenizer to `pipeline.py`,
- [ ] add a link to your conversion script in the main conversion utility (currently in `__main__` but will be moved to the `commands` sub-folder in the near future),
- [ ] edit the PyTorch to TF 2.0 conversion script to add your model in the `convert_pytorch_checkpoint_to_tf2.py` file,
- [ ] add a mention of your model in the doc: `README.md` and the documentation itself at `docs/source/pretrained_models.rst`,
- [ ] upload the pretrained weights, configuration and vocabulary files.
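For the first item, the `__init__.py` imports could look like the following. This is only a sketch assuming a model named `Xxx` with the heads provided by these templates; adapt the class names to what your model actually implements:

```python
# transformers/__init__.py (illustrative sketch for the "Xxx" templates)
from .configuration_xxx import XxxConfig, XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
from .tokenization_xxx import XxxTokenizer
from .modeling_xxx import (XxxModel, XxxForMaskedLM, XxxForSequenceClassification,
                           XxxForTokenClassification, XxxForQuestionAnswering,
                           load_tf_weights_in_xxx, XXX_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_xxx import (TFXxxModel, TFXxxForMaskedLM, TFXxxForSequenceClassification,
                              TFXxxForTokenClassification, TFXxxForQuestionAnswering,
                              TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP)
```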
|
130
templates/adding_a_new_model/configuration_xxx.py
Normal file
@ -0,0 +1,130 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2010, XXX authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" XXX model configuration """
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import six
|
||||
from io import open
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
|
||||
}
|
||||
|
||||
|
||||
class XxxConfig(PretrainedConfig):
|
||||
r"""
|
||||
:class:`~transformers.XxxConfig` is the configuration class to store the configuration of a
|
||||
`XxxModel`.
|
||||
|
||||
|
||||
Arguments:
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XxxModel`.
|
||||
hidden_size: Size of the encoder layers and the pooler layer.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||
layer in the Transformer encoder.
|
||||
hidden_act: The non-linear activation function (function or string) in the
|
||||
encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
|
||||
hidden_dropout_prob: The dropout probability for all fully connected
|
||||
layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||
probabilities.
|
||||
max_position_embeddings: The maximum sequence length that this model might
|
||||
ever be used with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||
`XxxModel`.
|
||||
initializer_range: The stddev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
layer_norm_eps: The epsilon used by LayerNorm.
|
||||
"""
|
||||
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=50257,
|
||||
n_positions=1024,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
|
||||
num_labels=1,
|
||||
summary_type='cls_index',
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
**kwargs):
|
||||
super(XxxConfig, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1  # overwritten below when a JSON config file is passed
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
self.num_labels = num_labels
|
||||
self.summary_type = summary_type
|
||||
self.summary_use_proj = summary_use_proj
|
||||
self.summary_activation = summary_activation
|
||||
self.summary_first_dropout = summary_first_dropout
|
||||
self.summary_proj_to_labels = summary_proj_to_labels
|
||||
if isinstance(vocab_size_or_config_json_file, six.string_types):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif not isinstance(vocab_size_or_config_json_file, int):
|
||||
raise ValueError(
|
||||
"First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)"
|
||||
)
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
return self.n_positions
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.n_embd
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.n_head
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.n_layer
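# Illustrative usage of this template configuration (the 'xxx' shortcut names above are
# placeholders, no real pretrained weights exist for them):
#   config = XxxConfig(n_embd=768, n_layer=12, n_head=12)
#   config.hidden_size          # -> 768, alias of n_embd via the property above
#   config.num_hidden_layers    # -> 12, alias of n_layer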
|
@ -0,0 +1,65 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert XXX checkpoint."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
||||
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, xxx_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = XxxConfig.from_json_file(xxx_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = XxxForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_xxx(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--xxx_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained XXX model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.xxx_config_file,
|
||||
args.pytorch_dump_path)
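# Example invocation (illustrative; the script file name and all paths are placeholders):
#   python convert_xxx_original_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path /path/to/xxx/model.ckpt \
#       --xxx_config_file /path/to/xxx_config.json \
#       --pytorch_dump_path /path/to/pytorch_model.bin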
|
500
templates/adding_a_new_model/modeling_tf_xxx.py
Normal file
@ -0,0 +1,500 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" TF 2.0 XXX model. """
|
||||
|
||||
####################################################
|
||||
# In this template, replace all the XXX (various casings) with your model name
|
||||
####################################################
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from .configuration_xxx import XxxConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
####################################################
|
||||
# This dict contains shortcut names and associated urls
|
||||
# for the pretrained weights provided with the models
|
||||
####################################################
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
|
||||
}
|
||||
|
||||
####################################################
|
||||
# TF 2.0 Models are constructed using Keras imperative API by sub-classing
|
||||
# - tf.keras.layers.Layer for the layers and
|
||||
# - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
|
||||
####################################################
|
||||
|
||||
####################################################
|
||||
# Here is an example of typical layer in a TF 2.0 model of the library
|
||||
# The classes are usually identical to the PyTorch ones and prefixed with 'TF'.
|
||||
#
|
||||
# Note that the class __init__ parameters include **kwargs (passed to 'super').
|
||||
# This lets us have control over class scope and variable names:
|
||||
# More precisely, we set the names of the class attributes (lower level layers) to
|
||||
# the equivalent attribute names in the PyTorch model so we can have equivalent
|
||||
# class and scope structure between PyTorch and TF 2.0 models and easily load one in the other.
|
||||
#
|
||||
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
||||
####################################################
|
||||
class TFXxxLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super(TFXxxLayer, self).__init__(**kwargs)
|
||||
self.attention = TFXxxAttention(config, name='attention')
|
||||
self.intermediate = TFXxxIntermediate(config, name='intermediate')
|
||||
self.transformer_output = TFXxxOutput(config, name='output')
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask = inputs
|
||||
|
||||
attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.transformer_output([intermediate_output, attention_output], training=training)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
####################################################
|
||||
# The full model without a specific pretrained or finetuning head is
|
||||
# provided as a tf.keras.layers.Layer usually called "TFXxxMainLayer"
|
||||
####################################################
|
||||
class TFXxxMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super(TFXxxMainLayer, self).__init__(**kwargs)
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
# We allow three types of multi-inputs:
|
||||
# - traditional keyword arguments in the call method
|
||||
# - all the arguments provided as a dict in the first positional argument of call
|
||||
# - all the arguments provided as a list/tuple (ordered) in the first positional argument of call
|
||||
# The last two options are useful to use the tf.keras fit() method.
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(tf.shape(input_ids), 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
##################################
|
||||
# Replace this with your model code
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
|
||||
sequence_output = encoder_outputs[0]
|
||||
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
|
||||
return outputs # sequence_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
####################################################
|
||||
# TFXxxPreTrainedModel is a sub-class of tf.keras.Model
|
||||
# which take care of loading and saving pretrained weights
|
||||
# and various common utilities.
|
||||
# Here you just need to specify a few (self-explanatory)
|
||||
# pointers for your model.
|
||||
####################################################
|
||||
class TFXxxPreTrainedModel(TFPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
XXX_START_DOCSTRING = r""" The XXX model was proposed in
|
||||
`XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
|
||||
by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
|
||||
pre-trained using a combination of masked language modeling objective and next sentence prediction
|
||||
on a large corpus comprising the Toronto Book Corpus and Wikipedia.
|
||||
|
||||
This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
|
||||
refer to the TF 2.0 documentation for all matter related to general usage and behavior.
|
||||
|
||||
.. _`XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
|
||||
https://arxiv.org/abs/1810.04805
|
||||
|
||||
.. _`tf.keras.Model`:
|
||||
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model
|
||||
|
||||
Note on the model inputs:
|
||||
TF 2.0 models accept two formats as inputs:
|
||||
|
||||
- having all inputs as keyword arguments (like PyTorch models), or
|
||||
- having all inputs as a list, tuple or dict in the first positional arguments.
|
||||
|
||||
This second option is useful when using the `tf.keras.Model.fit()` method, which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.
|
||||
|
||||
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:
|
||||
|
||||
- a single Tensor with input_ids only and nothing else: `model(input_ids)`
|
||||
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
||||
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
|
||||
- a dictionary with one or several input Tensors associated with the input names given in the docstring:
|
||||
`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.XxxConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
XXX_INPUTS_DOCSTRING = r"""
|
||||
Inputs:
|
||||
**input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
To match pre-training, XXX input sequence should be formatted with [CLS] and [SEP] tokens as follows:
|
||||
|
||||
(a) For sequence pairs:
|
||||
|
||||
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
|
||||
|
||||
(b) For single sequences:
|
||||
|
||||
``tokens: [CLS] the dog is hairy . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0``
|
||||
|
||||
Xxx is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||
the right rather than the left.
|
||||
|
||||
Indices can be obtained using :class:`transformers.XxxTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
(see `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
|
||||
**position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||
**head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class TFXxxModel(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
**pooler_output**: ``tf.Tensor`` of shape ``(batch_size, hidden_size)``
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during Xxx pretraining. This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import XxxTokenizer, TFXxxModel
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = TFXxxModel.from_pretrained('xxx-base-uncased')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import XxxTokenizer, TFXxxForMaskedLM
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = TFXxxForMaskedLM.from_pretrained('xxx-base-uncased')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
prediction_scores = outputs[0]
|
||||
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name='mlm')
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
|
||||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
|
||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**logits**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import XxxTokenizer, TFXxxForSequenceClassification
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = TFXxxForSequenceClassification.from_pretrained('xxx-base-uncased')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
logits = outputs[0]
|
||||
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
return outputs # logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
|
||||
Classification scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import XxxTokenizer, TFXxxForTokenClassification
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = TFXxxForTokenClassification.from_pretrained('xxx-base-uncased')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
scores = outputs[0]
|
||||
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
return outputs # scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**start_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import XxxTokenizer, TFXxxForQuestionAnswering
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = TFXxxForQuestionAnswering.from_pretrained('xxx-base-uncased')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
start_scores, end_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='qa_outputs')
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
outputs = (start_logits, end_logits,) + outputs[2:]
|
||||
|
||||
return outputs # start_logits, end_logits, (hidden_states), (attentions)
|
644
templates/adding_a_new_model/modeling_xxx.py
Normal file
@ -0,0 +1,644 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 XXX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch XXX model. """
|
||||
|
||||
####################################################
|
||||
# In this template, replace all the XXX (various casings) with your model name
|
||||
####################################################
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
||||
from .configuration_xxx import XxxConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
####################################################
|
||||
# This dict contains shortcut names and associated urls
|
||||
# for the pretrained weights provided with the models
|
||||
####################################################
|
||||
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
|
||||
}
|
||||
|
||||
####################################################
|
||||
# This is a conversion method from TF 1.0 to PyTorch
|
||||
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
|
||||
####################################################
|
||||
def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model.
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'squad':
|
||||
pointer = getattr(pointer, 'classifier')
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
except AttributeError:
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
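# A minimal usage sketch for the conversion helper above (illustrative only: the
# checkpoint path is a placeholder and you may want a task-specific head instead
# of the bare XxxModel):
#
#     config = XxxConfig.from_json_file("xxx_config.json")
#     model = XxxModel(config)
#     load_tf_weights_in_xxx(model, config, "path/to/tf_checkpoint/model.ckpt")
#     torch.save(model.state_dict(), "pytorch_model.bin")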
|
||||
|
||||
|
||||
####################################################
|
||||
# PyTorch Models are constructed by sub-classing
|
||||
# - torch.nn.Module for the layers and
|
||||
# - PreTrainedModel for the models (itself a sub-class of torch.nn.Module)
|
||||
####################################################
|
||||
|
||||
####################################################
|
||||
# Here is an example of typical layer in a PyTorch model of the library
|
||||
# The classes are usually identical to the TF 2.0 ones without the 'TF' prefix.
|
||||
#
|
||||
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
||||
####################################################
|
||||
class XxxLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(XxxLayer, self).__init__()
|
||||
self.attention = XxxAttention(config)
|
||||
self.intermediate = XxxIntermediate(config)
|
||||
self.output = XxxOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
####################################################
|
||||
# PreTrainedModel is a sub-class of torch.nn.Module
|
||||
# which takes care of loading and saving pretrained weights
|
||||
# and various common utilities.
|
||||
#
|
||||
# Here you just need to specify a few (self-explanatory)
|
||||
# pointers for your model and the weights initialization
|
||||
# method if it's not fully covered by PreTrainedModel's default method
|
||||
####################################################
|
||||
class XxxPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_xxx
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, XxxLayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
XXX_START_DOCSTRING = r""" The XXX model was proposed in
|
||||
`XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
|
||||
by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
|
||||
pre-trained using a combination of masked language modeling objective and next sentence prediction
|
||||
on a large corpus comprising the Toronto Book Corpus and Wikipedia.
|
||||
|
||||
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
|
||||
refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||
|
||||
.. _`XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
|
||||
https://arxiv.org/abs/1810.04805
|
||||
|
||||
.. _`torch.nn.Module`:
|
||||
https://pytorch.org/docs/stable/nn.html#module
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.XxxConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
XXX_INPUTS_DOCSTRING = r"""
|
||||
Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
To match pre-training, XXX input sequence should be formatted with [CLS] and [SEP] tokens as follows:
|
||||
|
||||
(a) For sequence pairs:
|
||||
|
||||
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
|
||||
|
||||
(b) For single sequences:
|
||||
|
||||
``tokens: [CLS] the dog is hairy . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0``
|
||||
|
||||
Xxx is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||
the right rather than the left.
|
||||
|
||||
Indices can be obtained using :class:`transformers.XxxTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
(see `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class XxxModel(XxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during Xxx pretraining. This output is usually *not* a good summary
|
||||
of the semantic content of the input; you're often better off averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = XxxModel.from_pretrained('xxx-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XxxModel, self).__init__(config)
|
||||
|
||||
self.embeddings = XxxEmbeddings(config)
|
||||
self.encoder = XxxEncoder(config)
|
||||
self.pooler = XxxPooler(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
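# For example, an attention_mask row of [1, 1, 0] becomes [0.0, 0.0, -10000.0] here,
# so the padded position receives a large negative bias and gets (near) zero weight
# once the attention scores go through the softmax.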
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicates we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
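# For example, a 1D head_mask of shape (num_heads,) ends up broadcast to
# (num_hidden_layers, 1, num_heads, 1, 1), i.e. the same per-head mask is applied in
# every layer, while a 2D (num_layers, num_heads) mask keeps one row per layer.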
|
||||
|
||||
##################################
|
||||
# Replace this with your model code
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
|
||||
return outputs # sequence_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for computing the masked language modeling loss.
|
||||
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Masked language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = XxxForMaskedLM.from_pretrained('xxx-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, masked_lm_labels=input_ids)
|
||||
loss, prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XxxForMaskedLM, self).__init__(config)
|
||||
|
||||
self.transformer = XxxModel(config)
|
||||
self.cls = XxxOnlyMLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
self._tie_or_clone_weights(self.cls.predictions.decoder,
|
||||
self.transformer.embeddings.word_embeddings)
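# A quick sanity check of the tying (illustrative sketch; when the model is exported
# to TorchScript the decoder weight is cloned rather than shared, so the identity
# check below would not hold in that case):
#
#     model = XxxForMaskedLM(config)
#     assert model.cls.predictions.decoder.weight is model.transformer.embeddings.word_embeddings.weight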
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
masked_lm_labels=None):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
if masked_lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
outputs = (masked_lm_loss,) + outputs
|
||||
|
||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = XxxForSequenceClassification.from_pretrained('xxx-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XxxForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = XxxModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, labels=None):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for computing the token classification loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss.
|
||||
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
|
||||
Classification scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = XxxForTokenClassification.from_pretrained('xxx-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XxxForTokenClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = XxxModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, labels=None):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
# Only keep active parts of the loss
|
||||
if attention_mask is not None:
|
||||
active_loss = attention_mask.view(-1) == 1
|
||||
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
||||
active_labels = labels.view(-1)[active_loss]
|
||||
loss = loss_fct(active_logits, active_labels)
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Positions outside of the sequence are not taken into account for computing the loss.
|
||||
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Positions outside of the sequence are not taken into account for computing the loss.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
||||
model = XxxForQuestionAnswering.from_pretrained('xxx-large-uncased-whole-word-masking-finetuned-squad')
|
||||
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||
input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]"
|
||||
input_ids = tokenizer.encode(input_text)
|
||||
token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
|
||||
start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
|
||||
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
|
||||
# a nice puppet
|
||||
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XxxForQuestionAnswering, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = XxxModel(config)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
start_positions=None, end_positions=None):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
outputs = (start_logits, end_logits,) + outputs[2:]
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split adds a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs; we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
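# For example, with a 384-token input (ignored_index == 384), a gold start position
# of 500 is clamped to 384 and is then skipped by the loss via ignore_index below.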
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
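# A minimal (hypothetical) fine-tuning step for the question-answering head above,
# assuming `batch` holds tensors produced by a SQuAD-style preprocessing pipeline;
# the optimizer and learning rate are just illustrative choices:
#
#     model = XxxForQuestionAnswering.from_pretrained('xxx-base-uncased')
#     optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
#     loss = model(batch['input_ids'],
#                  attention_mask=batch['attention_mask'],
#                  token_type_ids=batch['token_type_ids'],
#                  start_positions=batch['start_positions'],
#                  end_positions=batch['end_positions'])[0]
#     loss.backward()
#     optimizer.step()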
|
templates/adding_a_new_model/tests/modeling_tf_xxx_test.py (new file)
@@ -0,0 +1,256 @@
# coding=utf-8
|
||||
# Copyright 2018 XXX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
|
||||
from transformers import XxxConfig, is_tf_available
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers.modeling_tf_xxx import (TFXxxModel, TFXxxForMaskedLM,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification,
|
||||
TFXxxForQuestionAnswering,
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require TensorFlow")
|
||||
|
||||
|
||||
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
all_model_classes = (TFXxxModel, TFXxxForMaskedLM, TFXxxForQuestionAnswering,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification) if is_tf_available() else ()
|
||||
|
||||
class TFXxxModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = XxxConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = TFXxxModel(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
sequence_output, pooled_output = model(inputs)
|
||||
|
||||
inputs = [input_ids, input_mask]
|
||||
sequence_output, pooled_output = model(inputs)
|
||||
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output.numpy(),
|
||||
"pooled_output": pooled_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].shape),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = TFXxxForMaskedLM(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
prediction_scores, = model(inputs)
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFXxxForSequenceClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.num_labels])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFXxxForTokenClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = TFXxxForQuestionAnswering(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
start_logits, end_logits = model(inputs)
|
||||
result = {
|
||||
"start_logits": start_logits.numpy(),
|
||||
"end_logits": end_logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, token_type_ids, input_mask,
|
||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXxxModelTest.TFXxxModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=XxxConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_xxx_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_token_classification(*config_and_inputs)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in ['xxx-base-uncased']:
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
templates/adding_a_new_model/tests/modeling_xxx_test.py (new file)
@@ -0,0 +1,255 @@
# coding=utf-8
|
||||
# Copyright 2018 XXX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
|
||||
XxxForNextSentencePrediction, XxxForPreTraining,
|
||||
XxxForQuestionAnswering, XxxForSequenceClassification,
|
||||
XxxForTokenClassification, XxxForMultipleChoice)
|
||||
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
|
||||
class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering,
|
||||
XxxForSequenceClassification,
|
||||
XxxForTokenClassification) if is_torch_available() else ()
|
||||
|
||||
class XxxModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = XxxConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
|
||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = XxxModel(config=config)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = XxxForMaskedLM(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = XxxForQuestionAnswering(config=config)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = XxxForSequenceClassification(config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = XxxForTokenClassification(config=config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, token_type_ids, input_mask,
|
||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = XxxModelTest.XxxModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=XxxConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_xxx_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xxx_for_token_classification(*config_and_inputs)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
templates/adding_a_new_model/tests/tokenization_xxx_test.py (new file)
@@ -0,0 +1,57 @@
# coding=utf-8
|
||||
# Copyright 2018 XXX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
|
||||
from transformers.tokenization_xxx import (XxxTokenizer, VOCAB_FILES_NAMES)
|
||||
|
||||
from .tokenization_tests_commons import CommonTestCases
|
||||
|
||||
class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
|
||||
tokenizer_class = XxxTokenizer
|
||||
|
||||
def setUp(self):
|
||||
super(XxxTokenizationTest, self).setUp()
|
||||
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ",", "low", "lowest",
|
||||
]
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"UNwant\u00E9d,running"
|
||||
output_text = u"unwanted, running"
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
templates/adding_a_new_model/tokenization_xxx.py (new file)
@@ -0,0 +1,218 @@
# coding=utf-8
|
||||
# Copyright 2018 XXX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tokenization class for model XXX."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import unicodedata
|
||||
from io import open
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
####################################################
# In this template, replace all the XXX (various casings) with your model name
####################################################

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
        'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
    }
}

####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'xxx-base-uncased': 512,
    'xxx-large-uncased': 512,
}

####################################################
# Mapping from model shortcut names to a dictionary of additional
# keyword arguments for Tokenizer `__init__`.
# To be used for checkpoint specific configurations.
####################################################
PRETRAINED_INIT_CONFIGURATION = {
    'xxx-base-uncased': {'do_lower_case': True},
    'xxx-large-uncased': {'do_lower_case': True},
}

def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip('\n')
        vocab[token] = index
    return vocab

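# Note (illustrative): `load_vocab` expects a plain-text file with one token per
# line; the line index becomes the token id. For example, a vocab.txt containing
#     [UNK]
#     [CLS]
#     [SEP]
#     want
# maps "[UNK]" to id 0 and "want" to id 3.
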
class XxxTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a XxxTokenizer.
    :class:`~transformers.XxxTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, do_lower_case=True,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", **kwargs):
        """Constructs a XxxTokenizer.

        Args:
            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
            **do_lower_case**: (`optional`) boolean (default True)
                Whether to lower case the input.
                Only has an effect when do_basic_tokenize=True
        """
        super(XxxTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                           pad_token=pad_token, cls_token=cls_token,
                                           mask_token=mask_token, **kwargs)
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping (id -> token), needed by `_convert_id_to_token` below
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text):
        """ Take as input a string and return a list of strings (tokens) for words/sub-words
        """
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

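    # Note: `_tokenize` above assumes the model author provides `do_basic_tokenize`,
    # a `basic_tokenizer` and a `wordpiece_tokenizer` (for example the BasicTokenizer
    # and WordpieceTokenizer helpers used by BertTokenizer), or replaces the method
    # entirely with the sub-word scheme of the new model.
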
    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) to an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (string/unicode) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (strings) to a single string. """
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
            single sequence: [CLS] X [SEP]
            pair of sequences: [CLS] A [SEP] B [SEP]
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0: list of ids (must not contain special tokens)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
                for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted with
                special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError("You should not supply a second sequence if the provided sequence of "
                                 "ids is already formatted with special tokens for the model.")
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
            0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return (vocab_file,)
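For orientation, here is a minimal usage sketch of the tokenizer defined by this template. It assumes the XXX placeholders have been renamed, the new module has been wired into the `transformers` package, and a small one-token-per-line `vocab.txt` (containing at least the special tokens) is available; the module path, file name and tokens below are illustrative assumptions, not part of the template.

```python
# Illustrative sketch only; names and paths are assumptions.
from transformers.tokenization_xxx import XxxTokenizer

tokenizer = XxxTokenizer("vocab.txt", do_lower_case=True)

# Convert already-split WordPiece tokens to ids (bypasses _tokenize, which
# still needs a basic/wordpiece tokenizer or a custom implementation).
token_ids = tokenizer.convert_tokens_to_ids(["un", "##want", "##ed"])

# Add [CLS]/[SEP] and build the companion masks defined above.
input_ids = tokenizer.build_inputs_with_special_tokens(token_ids)
token_type_ids = tokenizer.create_token_type_ids_from_sequences(token_ids)
special_tokens_mask = tokenizer.get_special_tokens_mask(token_ids)

print(input_ids, token_type_ids, special_tokens_mask)
```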