add distillation+finetuning option in run_squad
parent bb464289ce
commit 764a7923ec
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
+""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet) with an optional step of distillation."""
 
 from __future__ import absolute_import, division, print_function
 
@@ -28,6 +28,8 @@ import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
+import torch.nn.functional as F
+import torch.nn as nn
 from tqdm import tqdm, trange
 
 from tensorboardX import SummaryWriter
@@ -73,7 +75,7 @@ def set_seed(args):
 def to_list(tensor):
     return tensor.detach().cpu().tolist()
 
-def train(args, train_dataset, model, tokenizer):
+def train(args, train_dataset, model, tokenizer, teacher=None):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -132,17 +134,40 @@ def train(args, train_dataset, model, tokenizer):
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
             model.train()
+            if teacher is not None:
+                teacher.eval()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids':       batch[0],
                       'attention_mask':  batch[1],
-                      'token_type_ids':  None if args.model_type == 'xlm' else batch[2],
                       'start_positions': batch[3],
                       'end_positions':   batch[4]}
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
             if args.model_type in ['xlnet', 'xlm']:
                 inputs.update({'cls_index': batch[5],
                                'p_mask':    batch[6]})
             outputs = model(**inputs)
-            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+            loss, start_logits_stu, end_logits_stu = outputs
+
+            # Distillation loss
+            if teacher is not None:
+                if 'token_type_ids' not in inputs:
+                    inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2]
+                with torch.no_grad():
+                    start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'],
+                                                               token_type_ids=inputs['token_type_ids'],
+                                                               attention_mask=inputs['attention_mask'])
+                assert start_logits_tea.size() == start_logits_stu.size()
+                assert end_logits_tea.size() == end_logits_stu.size()
+
+                loss_fct = nn.KLDivLoss(reduction='batchmean')
+                loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1),
+                                      F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
+                loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1),
+                                    F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
+                loss_ce = (loss_start + loss_end)/2.
+
+                loss = args.alpha_ce*loss_ce + args.alpha_squad*loss
 
             if args.n_gpu > 1:
                 loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
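For quick reference, the block added above is the standard temperature-scaled soft-target loss: the teacher's temperature-softened probabilities are the target of a KLDivLoss on the student's temperature-softened log-probabilities, each term is multiplied by temperature squared, the start and end heads are averaged, and the result is linearly mixed with the usual SQuAD loss through alpha_ce and alpha_squad. Below is a minimal standalone sketch of that computation on random stand-in logits; apart from the default hyperparameter values, none of these names come from the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

temperature, alpha_ce, alpha_squad = 2.0, 0.5, 0.5   # the script's defaults
batch_size, seq_len = 8, 384                         # typical SQuAD shapes
start_logits_stu = torch.randn(batch_size, seq_len)  # student start logits (stand-in)
end_logits_stu = torch.randn(batch_size, seq_len)
start_logits_tea = torch.randn(batch_size, seq_len)  # teacher start logits (stand-in)
end_logits_tea = torch.randn(batch_size, seq_len)
squad_loss = torch.tensor(1.7)                       # stand-in for the student's own QA loss

loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_start = loss_fct(F.log_softmax(start_logits_stu / temperature, dim=-1),
                      F.softmax(start_logits_tea / temperature, dim=-1)) * (temperature ** 2)
loss_end = loss_fct(F.log_softmax(end_logits_stu / temperature, dim=-1),
                    F.softmax(end_logits_tea / temperature, dim=-1)) * (temperature ** 2)
loss_ce = (loss_start + loss_end) / 2.
loss = alpha_ce * loss_ce + alpha_squad * squad_loss  # same linear mix as in the hunk above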
@@ -218,9 +243,10 @@ def evaluate(args, model, tokenizer, prefix=""):
         batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
             inputs = {'input_ids':      batch[0],
-                      'attention_mask': batch[1],
-                      'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
+                      'attention_mask': batch[1]
                       }
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
             example_indices = batch[3]
             if args.model_type in ['xlnet', 'xlm']:
                 inputs.update({'cls_index': batch[4],
@@ -341,6 +367,18 @@ def main():
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model checkpoints and predictions will be written.")
 
+    # Distillation parameters (optional)
+    parser.add_argument('--teacher_type', default=None, type=str,
+                        help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.")
+    parser.add_argument('--teacher_name_or_path', default=None, type=str,
+                        help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.")
+    parser.add_argument('--alpha_ce', default=0.5, type=float,
+                        help="Distillation loss linear weight. Only for distillation.")
+    parser.add_argument('--alpha_squad', default=0.5, type=float,
+                        help="True SQuAD loss linear weight. Only for distillation.")
+    parser.add_argument('--temperature', default=2.0, type=float,
+                        help="Distillation temperature. Only for distillation.")
+
     ## Other parameters
     parser.add_argument("--config_name", default="", type=str,
                         help="Pretrained config name or path if not the same as model_name")
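As a quick illustration of the new surface area, the sketch below re-creates just these five options with a stock argparse parser and shows their documented defaults; it does not import run_squad.py itself, and the teacher checkpoint path is a placeholder.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--teacher_type', default=None, type=str)
parser.add_argument('--teacher_name_or_path', default=None, type=str)
parser.add_argument('--alpha_ce', default=0.5, type=float)
parser.add_argument('--alpha_squad', default=0.5, type=float)
parser.add_argument('--temperature', default=2.0, type=float)

# Placeholder checkpoint path; any SQuAD-fine-tuned, non-DistilBERT model is expected here.
args = parser.parse_args(['--teacher_type', 'bert',
                          '--teacher_name_or_path', '/path/to/bert-squad-checkpoint'])

# With the defaults, training optimizes 0.5 * soft-target KL + 0.5 * regular SQuAD loss
# at softmax temperature 2.0.
print(args.alpha_ce, args.alpha_squad, args.temperature)  # 0.5 0.5 2.0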
@@ -468,6 +506,18 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
     model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
 
+    if args.teacher_type is not None:
+        assert args.teacher_name_or_path is not None
+        assert args.alpha_ce > 0.
+        assert args.alpha_ce + args.alpha_squad > 0.
+        assert args.teacher_type != 'distilbert', "We constraint teachers not to be of type DistilBERT."
+        teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
+        teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
+        teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path, config=teacher_config)
+        teacher.to(args.device)
+    else:
+        teacher = None
+
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
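The teacher is loaded once here and only ever run under torch.no_grad() in the training loop. Outside the script, the pairing this code path expects looks roughly like the sketch below: a SQuAD-fine-tuned BERT teacher and a DistilBERT student whose tokenizers agree. This is an illustration, not part of the commit; the checkpoint path is a placeholder, and the import name depends on the installed library version.

import torch
from transformers import (BertForQuestionAnswering, BertTokenizer,
                          DistilBertForQuestionAnswering, DistilBertTokenizer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Student: the model actually being fine-tuned (and distilled into).
student = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
student_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Teacher: already fine-tuned on SQuAD, kept frozen during training.
teacher = BertForQuestionAnswering.from_pretrained("/path/to/bert-squad-checkpoint")  # placeholder path
teacher_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
teacher.to(device)
teacher.eval()

# The --teacher_type help text requires both tokenizers to produce the same tokenization;
# a quick sanity check on a sample sentence:
sample = "Distillation keeps the teacher frozen while the student learns."
assert student_tokenizer.tokenize(sample) == teacher_tokenizer.tokenize(sample)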
@@ -478,7 +528,7 @@ def main():
     # Training
     if args.do_train:
         train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
-        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+        global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
 