[distillation] big update w/ new weights
parent 0d1dad6d53 · commit 354944e607
@@ -2,12 +2,21 @@

This folder contains the original code used to train DistilBERT as well as examples showcasing how to use DistilBERT.

**2019, September 19th - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 97% of `BERT-base`'s performance on GLUE, and an 86.9 F1 score on the SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!

## What is DistilBERT

DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster, while preserving over 95% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique that compresses a large model, called the teacher, into a smaller model, called the student. By distilling BERT, we obtain a smaller Transformer model that bears many similarities to the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large-scale pre-trained Transformer models into production.
DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster, while preserving 97% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique that compresses a large model, called the teacher, into a smaller model, called the student. By distilling BERT, we obtain a smaller Transformer model that bears many similarities to the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large-scale pre-trained Transformer models into production.

For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
).
). *Please note that we will publish a formal write-up with updated and more complete results in the near future (September 19th).*

Here are the updated results on the dev sets of GLUE:

| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | WNLI |
| :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:|
| BERT-base | **77.6** | 48.9 | 84.3 | 88.6 | 89.3 | 89.5 | 71.3 | 91.7 | 91.2 | 43.7 |
| DistilBERT | **75.2** | 49.1 | 81.8 | 90.2 | 87.0 | 89.2 | 62.9 | 92.7 | 90.7 | 44.4 |

## Setup
@@ -20,7 +29,7 @@ This part of the library has only been tested with Python 3.6+. There are few speci
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility of training and releasing a multilingual version of DistilBERT):

- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (the concatenation of the Toronto Book Corpus and full English Wikipedia), using distillation with the supervision of the `bert-base-uncased` version of BERT. The model has 6 layers, a hidden size of 768 and 12 heads, totaling 66M parameters.
- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.2 on the dev set (for comparison, the `bert-base-uncased` version of BERT reaches an 88.5 F1 score).
- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.9 on the dev set (for comparison, the `bert-base-uncased` version of BERT reaches an 88.5 F1 score).

Using DistilBERT is very similar to using BERT. DistilBERT shares the same tokenizer as BERT's `bert-base-uncased`, even though we provide a link to this tokenizer under the `DistilBertTokenizer` name for consistent naming across the library's models.

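For orientation while reading this diff, here is a minimal usage sketch of the released checkpoints with `pytorch_transformers` (the version pinned in `requirements.txt`); the input sentence and the `eval()`/`no_grad()` wrapping are illustrative, not part of the diff:

```python
import torch
from pytorch_transformers import DistilBertTokenizer, DistilBertModel

# The shared BERT tokenizer (exposed under the DistilBertTokenizer name)
# and the distilled 6-layer encoder.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
model.eval()

input_ids = torch.tensor([tokenizer.encode("DistilBERT is smaller and faster than BERT.")])
with torch.no_grad():
    last_hidden_state = model(input_ids)[0]  # (batch_size, seq_length, 768)
print(last_hidden_state.shape)
```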
@@ -92,11 +92,11 @@ class Dataset:
        Too short sequences are simply removed. This could be tuned.
        """
        init_size = len(self)
        indices = self.lengths > 5
        indices = self.lengths > 11
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.')
        logger.info(f'Remove {init_size - new_size} too short (<=11 tokens) sequences.')

    def print_statistics(self):
        """
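The threshold change above (from 5 to 11 tokens) is a plain boolean-mask filter over parallel NumPy arrays; a tiny self-contained sketch with made-up data:

```python
import numpy as np

# Toy stand-ins for the dataset's parallel arrays of token-id sequences and lengths.
token_ids = np.array([np.array([101, 2009, 102]),
                      np.array([101] + list(range(2000, 2020)) + [102])], dtype=object)
lengths = np.array([len(t) for t in token_ids])

keep = lengths > 11                     # same predicate as in the diff
token_ids, lengths = token_ids[keep], lengths[keep]
print(f'Kept {len(lengths)} sequence(s) longer than 11 tokens.')
```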
@@ -18,15 +18,18 @@
import os
import math
import psutil
import time
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
import numpy as np
import psutil

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

from pytorch_transformers import AdamW, WarmupLinearSchedule
from pytorch_transformers import WarmupLinearSchedule

from utils import logger
from dataset import Dataset
@@ -58,10 +61,12 @@ class Distiller:
        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.
        assert self.alpha_cos >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
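The four `alpha_*` knobs simply form a weighted sum of the individual distillation terms, which is what the new asserts guard; a schematic sketch (the loss values and weights below are arbitrary illustrations, only the zero defaults for `alpha_mse`/`alpha_cos` come from the diff):

```python
def total_distillation_loss(loss_ce, loss_mlm, loss_mse, loss_cos,
                            alpha_ce, alpha_mlm, alpha_mse=0.0, alpha_cos=0.0):
    """Weighted sum of the individual terms; a weight of 0.0 disables a term,
    and at least one weight must be positive (mirroring the asserts above)."""
    assert min(alpha_ce, alpha_mlm, alpha_mse, alpha_cos) >= 0.
    assert alpha_ce + alpha_mlm + alpha_mse + alpha_cos > 0.
    return (alpha_ce * loss_ce + alpha_mlm * loss_mlm
            + alpha_mse * loss_mse + alpha_cos * loss_cos)

# Arbitrary toy values, just to show the combination:
print(total_distillation_loss(2.3, 1.7, 0.0, 0.12,
                              alpha_ce=0.5, alpha_mlm=0.5, alpha_cos=1.0))
```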
@@ -81,17 +86,21 @@ class Distiller:
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_mse = 0
        if self.alpha_mse > 0.: self.last_loss_mse = 0
        if self.alpha_cos > 0.: self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
@@ -104,9 +113,12 @@ class Distiller:
                                   lr=params.learning_rate,
                                   eps=params.adam_epsilon,
                                   betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        logger.info(f'--- Scheduler: {params.scheduler_type}')
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=num_train_optimization_steps)
            warmup_steps=warmup_steps,
            t_total=num_train_optimization_steps)

        if self.fp16:
            try:
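Since this hunk touches both the optimizer and the scheduler, here is a condensed, self-contained sketch of the same setup; the model, learning rate, weight decay and step counts are placeholders, not the script's defaults:

```python
import math
import torch.nn as nn
from torch.optim import AdamW
from pytorch_transformers import WarmupLinearSchedule

model = nn.Linear(768, 768)  # stand-in for the student

# Weight decay is applied to everything except biases and LayerNorm weights.
no_decay = ['bias', 'LayerNorm.weight']
grouped_params = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = AdamW(grouped_params, lr=5e-4, eps=1e-6, betas=(0.9, 0.98))

num_train_optimization_steps = 10_000                      # placeholder
warmup_steps = math.ceil(num_train_optimization_steps * 0.05)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                 t_total=num_train_optimization_steps)
```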
@@ -272,11 +284,14 @@ class Distiller:
        The real training loop.
        """
        if self.is_master: logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
@@ -314,9 +329,9 @@ class Distiller:
            attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
        """
        s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0]               # (bs, seq_length, voc_size)
        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0]               # (bs, seq_length, voc_size)
            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
@@ -340,6 +355,22 @@ class Distiller:
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        if self.alpha_cos > 0.:
            s_hidden_states = s_hidden_states[-1]                              # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]                              # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)     # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)          # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)          # (bs * seq_length, dim)

            target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
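The new `alpha_cos` term aligns the directions of student and teacher hidden states on non-padded positions. A standalone sketch of that computation with random tensors (shapes only, no real model; the padding pattern is invented):

```python
import torch
import torch.nn as nn

bs, seq_length, dim = 2, 8, 768
s_hidden = torch.randn(bs, seq_length, dim)              # student last hidden layer
t_hidden = torch.randn(bs, seq_length, dim)              # teacher last hidden layer
attention_mask = torch.ones(bs, seq_length, dtype=torch.bool)
attention_mask[1, 5:] = False                            # pretend the second sequence is padded

mask = attention_mask.unsqueeze(-1).expand_as(s_hidden)  # (bs, seq_length, dim)
s_slct = torch.masked_select(s_hidden, mask).view(-1, dim)
t_slct = torch.masked_select(t_hidden, mask).view(-1, dim)

# CosineEmbeddingLoss with a target of +1 pushes the cosine similarity toward 1
# for every kept (student, teacher) vector pair.
target = s_slct.new_ones(s_slct.size(0))
loss_cos = nn.CosineEmbeddingLoss(reduction='mean')(s_slct, t_slct, target)
print(loss_cos.item())
```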
@@ -348,6 +379,8 @@ class Distiller:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

@@ -396,6 +429,7 @@ class Distiller:

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

@@ -421,9 +455,12 @@ class Distiller:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        if self.alpha_cos > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)

    def end_epoch(self):
        """
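The new per-loss scalars rely only on `tensorboardX` (pinned at 1.8 in `requirements.txt`); a minimal logging sketch with dummy values and a hypothetical log directory:

```python
import time
import psutil
from tensorboardX import SummaryWriter

tb = SummaryWriter('runs/distillation_demo')  # hypothetical log directory
step, last_log = 0, time.time()

# Dummy values standing in for one optimization step's losses.
tb.add_scalar(tag='losses/loss_ce', scalar_value=2.31, global_step=step)
tb.add_scalar(tag='losses/loss_cos', scalar_value=0.12, global_step=step)
tb.add_scalar(tag='global/memory_usage',
              scalar_value=psutil.virtual_memory()._asdict()['used'] / 1_000_000,
              global_step=step)
tb.add_scalar(tag='global/speed', scalar_value=time.time() - last_log, global_step=step)
tb.close()
```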
@@ -2,3 +2,5 @@ gitpython==3.0.2
tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.3
scipy==1.3.1
pytorch_transformers==1.2.0
@@ -20,7 +20,7 @@ import pickle
import random
import time
import numpy as np
from pytorch_transformers import BertTokenizer
from pytorch_transformers import BertTokenizer, RobertaTokenizer
import logging

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -32,16 +32,21 @@ def main():
    parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times (tokenization + token_to_ids).")
    parser.add_argument('--file_path', type=str, default='data/dump.txt',
                        help='The path to the data.')
    parser.add_argument('--bert_tokenizer', type=str, default='bert-base-uncased',
    parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta'])
    parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
                        help="The tokenizer to use.")
    parser.add_argument('--dump_file', type=str, default='data/dump',
                        help='The dump file prefix.')
    args = parser.parse_args()


    logger.info(f'Loading Tokenizer ({args.bert_tokenizer})')
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

    logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
    if args.tokenizer_type == 'bert':
        tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
    elif args.tokenizer_type == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
    bos = tokenizer.special_tokens_map['bos_token'] # `[CLS]` for bert, `<s>` for roberta
    sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` for bert, `</s>` for roberta

    logger.info(f'Loading text from {args.file_path}')
    with open(args.file_path, 'r', encoding='utf8') as fp:
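The point of the tokenizer switch above is that BERT-style and RoBERTa-style inputs use different boundary tokens (`[CLS] ... [SEP]` versus `<s> ... </s>`). A small sketch of the same wrapping with both tokenizers; it downloads the pretrained vocabularies, and the example sentence is arbitrary:

```python
from pytorch_transformers import BertTokenizer, RobertaTokenizer

bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
roberta_tok = RobertaTokenizer.from_pretrained('roberta-base')

text = "Knowledge distillation compresses a teacher into a student."

# BERT-style wrapping uses `[CLS] ... [SEP]`, RoBERTa-style wrapping uses `<s> ... </s>`.
bert_ids = bert_tok.encode(f'{bert_tok.cls_token} {text} {bert_tok.sep_token}')
roberta_ids = roberta_tok.encode(f'<s> {text} </s>')

print(len(bert_ids), len(roberta_ids))
```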
@@ -56,8 +61,8 @@ def main():
    interval = 10000
    start = time.time()
    for text in data:
        text = f'[CLS] {text.strip()} [SEP]'
        token_ids = bert_tokenizer.encode(text)
        text = f'{bos} {text.strip()} {sep}'
        token_ids = tokenizer.encode(text)
        rslt.append(token_ids)

        iter += 1
@@ -69,7 +74,7 @@ def main():
    logger.info(f'{len(data)} examples processed.')


    dp_file = f'{args.dump_file}.{args.bert_tokenizer}.pickle'
    dp_file = f'{args.dump_file}.{args.tokenizer_name}.pickle'
    rslt_ = [np.uint16(d) for d in rslt]
    random.shuffle(rslt_)
    logger.info(f'Dump to {dp_file}')
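The dump step stores every sequence as `uint16` ids (safe because both vocabularies stay well under 65,536 entries) inside a single pickle; a tiny sketch with toy token ids and a hypothetical output path:

```python
import pickle
import random
import numpy as np

rslt = [[101, 7592, 2088, 102], [101, 2023, 2003, 1037, 2936, 6251, 102]]  # toy token ids

rslt_ = [np.uint16(d) for d in rslt]   # each sequence becomes a uint16 array (2 bytes per token)
random.shuffle(rslt_)

dp_file = '/tmp/dump.bert-base-uncased.pickle'  # hypothetical path
with open(dp_file, 'wb') as handle:
    pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
```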
@@ -15,59 +15,73 @@
"""
Preprocessing script before training DistilBERT.
"""
from pytorch_transformers import BertForPreTraining
from pytorch_transformers import BertForMaskedLM, RobertaForMaskedLM
import torch
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Extract some layers of the full BertForPreTraining for Transfer Learned Distillation")
    parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
    parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str)
    parser = argparse.ArgumentParser(description="Extract some layers of the full BertForMaskedLM or RobertaForMaskedLM for Transfer Learned Distillation")
    parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
    parser.add_argument("--model_name", default='bert-base-uncased', type=str)
    parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
    parser.add_argument("--vocab_transform", action='store_true')
    args = parser.parse_args()


    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.model_type == 'bert':
        model = BertForMaskedLM.from_pretrained(args.model_name)
        prefix = 'bert'
    elif args.model_type == 'roberta':
        model = RobertaForMaskedLM.from_pretrained(args.model_name)
        prefix = 'roberta'

    state_dict = model.state_dict()
    compressed_sd = {}

    for w in ['word_embeddings', 'position_embeddings']:
        compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
            state_dict[f'bert.embeddings.{w}.weight']
            state_dict[f'{prefix}.embeddings.{w}.weight']
    for w in ['weight', 'bias']:
        compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
            state_dict[f'bert.embeddings.LayerNorm.{w}']
            state_dict[f'{prefix}.embeddings.LayerNorm.{w}']

    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        for w in ['weight', 'bias']:
            compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}']
            compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}']
            compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}']

            compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
            compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']

            compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
            compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}']
            compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
        std_idx += 1

    compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
    compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
    if args.vocab_transform:
        for w in ['weight', 'bias']:
            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
    if args.model_type == 'bert':
        compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
        compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
        if args.vocab_transform:
            for w in ['weight', 'bias']:
                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
    elif args.model_type == 'roberta':
        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
        if args.vocab_transform:
            for w in ['weight', 'bias']:
                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']

    print(f'N layers selected for distillation: {std_idx}')
    print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}')
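The layer-selection logic above maps every other teacher layer (indices 0, 2, 4, 7, 9, 11) onto consecutive student layers 0-5 by rewriting state-dict keys; a generic sketch of that renaming on a dummy state dict (key patterns simplified to one weight per layer for illustration):

```python
import torch

# Dummy teacher state dict with 12 "layers" of a single weight each (illustrative keys only).
teacher_sd = {f'bert.encoder.layer.{i}.attention.self.query.weight': torch.randn(4, 4)
              for i in range(12)}

compressed_sd = {}
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:          # same selection as the script
    compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.weight'] = \
        teacher_sd[f'bert.encoder.layer.{teacher_idx}.attention.self.query.weight']
    std_idx += 1

print(list(compressed_sd.keys()))
```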
@@ -23,7 +23,7 @@ import shutil
import numpy as np
import torch

from pytorch_transformers import BertTokenizer, BertForMaskedLM
from pytorch_transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
from pytorch_transformers import DistilBertForMaskedLM, DistilBertConfig

from distiller import Distiller
@@ -70,8 +70,10 @@ def main():
                        help="Load student initialization checkpoint.")
    parser.add_argument("--from_pretrained_config", default=None, type=str,
                        help="Load student initialization architecture config.")
    parser.add_argument("--bert_model", default='bert-base-uncased', type=str,
                        help="The teacher BERT model.")
    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
                        help="Teacher type (BERT, RoBERTa).")
    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
                        help="The teacher model.")

    parser.add_argument("--temperature", default=2., type=float,
                        help="Temperature for the distillation softmax.")
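The `--temperature` argument controls how much both distributions are softened before the KL-divergence term (the `nn.KLDivLoss(reduction='batchmean')` set up in `distiller.py`) is applied; a schematic sketch with random logits, leaving aside any additional scaling of the result:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

temperature = 2.0                      # the default shown above
s_logits = torch.randn(16, 30522)      # random stand-ins for student logits over the vocabulary
t_logits = torch.randn(16, 30522)      # random stand-ins for teacher logits

# KLDivLoss expects log-probabilities as input and probabilities as target.
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_ce = ce_loss_fct(F.log_softmax(s_logits / temperature, dim=-1),
                      F.softmax(t_logits / temperature, dim=-1))
print(loss_ce.item())
```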
@@ -81,6 +83,8 @@ def main():
                        help="Linear weight for the MLM loss. Must be >=0.")
    parser.add_argument("--alpha_mse", default=0.0, type=float,
                        help="Linear weight of the MSE loss. Must be >=0.")
    parser.add_argument("--alpha_cos", default=0.0, type=float,
                        help="Linear weight of the cosine embedding loss. Must be >=0.")
    parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
                        help="Proportion of tokens for which we need to make a prediction.")
    parser.add_argument("--word_mask", default=0.8, type=float,
@@ -165,11 +169,14 @@ def main():


    ### TOKENIZER ###
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    if args.teacher_type == 'bert':
        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
    elif args.teacher_type == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
    special_tok_ids = {}
    for tok_name, tok_symbol in bert_tokenizer.special_tokens_map.items():
        idx = bert_tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = bert_tokenizer.all_special_ids[idx]
    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
        idx = tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
    logger.info(f'Special tokens {special_tok_ids}')
    args.special_tok_ids = special_tok_ids

@@ -202,11 +209,12 @@ def main():
        logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
        logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
        stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
        stu_architecture_config.output_hidden_states = True
        student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
                                                        config=stu_architecture_config)
            config=stu_architecture_config)
    else:
        args.vocab_size_or_config_json_file = args.vocab_size
        stu_architecture_config = DistilBertConfig(**vars(args))
        stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
        student = DistilBertForMaskedLM(stu_architecture_config)

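The functional change here is that the student is now built with `output_hidden_states=True`, so its forward pass also returns the hidden states needed by the cosine loss. A minimal sketch with the library's default DistilBERT configuration (randomly initialized weights and a toy input, for illustration only):

```python
import torch
from pytorch_transformers import DistilBertConfig, DistilBertForMaskedLM

config = DistilBertConfig(output_hidden_states=True)  # default 6-layer architecture
student = DistilBertForMaskedLM(config)               # randomly initialized student
student.eval()

input_ids = torch.randint(0, 30522, (1, 12))          # 30522 = bert-base-uncased vocab size
with torch.no_grad():
    outputs = student(input_ids)

logits, hidden_states = outputs[0], outputs[1]
print(logits.shape)          # (1, 12, vocab_size)
print(len(hidden_states))    # tuple of hidden states; the last one feeds the cosine loss
```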
@@ -216,10 +224,13 @@ def main():


    ## TEACHER ##
    teacher = BertForMaskedLM.from_pretrained(args.bert_model)
    if args.teacher_type == 'bert':
        teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
    elif args.teacher_type == 'roberta':
        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
    if args.n_gpu > 0:
        teacher.to(f'cuda:{args.local_rank}')
    logger.info(f'Teacher loaded from {args.bert_model}.')
    logger.info(f'Teacher loaded from {args.teacher_name}.')

    ## DISTILLER ##
    torch.cuda.empty_cache()