diff --git a/examples/run_squad.py b/examples/run_squad.py
index 0c0fbf29636..922a3230875 100644
--- a/examples/run_squad.py
+++ b/examples/run_squad.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
+""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet) with an optional step of distillation."""
 
 from __future__ import absolute_import, division, print_function
 
@@ -28,6 +28,8 @@ import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
+import torch.nn.functional as F
+import torch.nn as nn
 from tqdm import tqdm, trange
 
 from tensorboardX import SummaryWriter
@@ -73,7 +75,7 @@ def set_seed(args):
 def to_list(tensor):
     return tensor.detach().cpu().tolist()
 
-def train(args, train_dataset, model, tokenizer):
+def train(args, train_dataset, model, tokenizer, teacher=None):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -132,17 +134,40 @@ def train(args, train_dataset, model, tokenizer):
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
             model.train()
+            if teacher is not None:
+                teacher.eval()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
                       'attention_mask': batch[1],
-                      'token_type_ids': None if args.model_type == 'xlm' else batch[2],
                       'start_positions': batch[3],
                       'end_positions': batch[4]}
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
             if args.model_type in ['xlnet', 'xlm']:
                 inputs.update({'cls_index': batch[5],
                                'p_mask': batch[6]})
             outputs = model(**inputs)
-            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+            loss, start_logits_stu, end_logits_stu = outputs
+
+            # Distillation loss
+            if teacher is not None:
+                if 'token_type_ids' not in inputs:
+                    inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2]
+                with torch.no_grad():
+                    start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'],
+                                                               token_type_ids=inputs['token_type_ids'],
+                                                               attention_mask=inputs['attention_mask'])
+                assert start_logits_tea.size() == start_logits_stu.size()
+                assert end_logits_tea.size() == end_logits_stu.size()
+
+                loss_fct = nn.KLDivLoss(reduction='batchmean')
+                loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1),
+                                      F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
+                loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1),
+                                    F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
+                loss_ce = (loss_start + loss_end)/2.
+
+                loss = args.alpha_ce*loss_ce + args.alpha_squad*loss
 
             if args.n_gpu > 1:
                 loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
@@ -218,9 +243,10 @@ def evaluate(args, model, tokenizer, prefix=""):
         batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
             inputs = {'input_ids': batch[0],
-                      'attention_mask': batch[1],
-                      'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
+                      'attention_mask': batch[1]
                       }
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
             example_indices = batch[3]
             if args.model_type in ['xlnet', 'xlm']:
                 inputs.update({'cls_index': batch[4],
@@ -341,6 +367,18 @@ def main():
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model checkpoints and predictions will be written.")
 
+    # Distillation parameters (optional)
+    parser.add_argument('--teacher_type', default=None, type=str,
+                        help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.")
+    parser.add_argument('--teacher_name_or_path', default=None, type=str,
+                        help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.")
+    parser.add_argument('--alpha_ce', default=0.5, type=float,
+                        help="Distillation loss linear weight. Only for distillation.")
+    parser.add_argument('--alpha_squad', default=0.5, type=float,
+                        help="True SQuAD loss linear weight. Only for distillation.")
+    parser.add_argument('--temperature', default=2.0, type=float,
+                        help="Distillation temperature. Only for distillation.")
+
     ## Other parameters
     parser.add_argument("--config_name", default="", type=str,
                         help="Pretrained config name or path if not the same as model_name")
@@ -468,6 +506,18 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
     model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
 
+    if args.teacher_type is not None:
+        assert args.teacher_name_or_path is not None
+        assert args.alpha_ce > 0.
+        assert args.alpha_ce + args.alpha_squad > 0.
+        assert args.teacher_type != 'distilbert', "We constraint teachers not to be of type DistilBERT."
+        teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
+        teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
+        teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path, config=teacher_config)
+        teacher.to(args.device)
+    else:
+        teacher = None
+
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
@@ -478,7 +528,7 @@ def main():
     # Training
     if args.do_train:
         train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
-        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+        global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
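
For context, a minimal standalone sketch of the distillation objective this patch adds: a temperature-scaled KL divergence between student and teacher logits, linearly combined with the regular hard-label loss. The patch applies it to both start and end logits and averages the two; the sketch below uses start logits only, and the tensors, the gold `start_positions`, and the `loss_squad` cross-entropy term are dummy placeholders for illustration, not code from the patch.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative values mirroring the patch defaults.
temperature = 2.0
alpha_ce = 0.5      # weight of the distillation (KL) loss
alpha_squad = 0.5   # weight of the regular SQuAD loss

# Dummy student/teacher start logits: batch of 2 examples, 8 token positions.
start_logits_stu = torch.randn(2, 8, requires_grad=True)
start_logits_tea = torch.randn(2, 8)
start_positions = torch.tensor([1, 5])  # dummy gold start indices

# Temperature-scaled KL divergence between the softened distributions;
# the (temperature ** 2) factor keeps gradient magnitudes comparable across temperatures.
kl_loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_ce = kl_loss_fct(F.log_softmax(start_logits_stu / temperature, dim=-1),
                      F.softmax(start_logits_tea / temperature, dim=-1)) * (temperature ** 2)

# Hard-label loss on the same logits (stand-in for the model's own SQuAD loss).
loss_squad = F.cross_entropy(start_logits_stu, start_positions)

# Linear combination, as in the patch: alpha_ce * distillation loss + alpha_squad * SQuAD loss.
loss = alpha_ce * loss_ce + alpha_squad * loss_squad
loss.backward()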