mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update distiller
This commit is contained in:
parent
a12ab0a8db
commit
bb9c5ead54
@ -12,8 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" The distiller to distil DistilBERT
|
||||
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
""" The distiller to distil the student.
|
||||
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||
"""
|
||||
import os
|
||||
import math
|
||||
@ -28,16 +28,19 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data import RandomSampler, BatchSampler, DataLoader
|
||||
|
||||
from transformers import WarmupLinearSchedule
|
||||
|
||||
from utils import logger
|
||||
from dataset import Dataset
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||
|
||||
class Distiller:
|
||||
def __init__(self,
|
||||
params: dict,
|
||||
dataloader: Dataset,
|
||||
dataset: LmSeqsDataset,
|
||||
token_probs: torch.tensor,
|
||||
student: nn.Module,
|
||||
teacher: nn.Module):
|
||||
@ -50,33 +53,47 @@ class Distiller:
|
||||
self.student = student
|
||||
self.teacher = teacher
|
||||
|
||||
self.dataloader = dataloader
|
||||
if self.params.n_gpu > 1:
|
||||
self.dataloader.split()
|
||||
self.get_iterator(seed=params.seed)
|
||||
self.student_config = student.config
|
||||
self.vocab_size = student.config.vocab_size
|
||||
|
||||
if params.n_gpu <= 1:
|
||||
sampler = RandomSampler(dataset)
|
||||
else:
|
||||
sampler = DistributedSampler(dataset)
|
||||
|
||||
if params.group_by_size:
|
||||
groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
|
||||
sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
|
||||
else:
|
||||
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
|
||||
|
||||
self.dataloader = DataLoader(dataset=dataset,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=dataset.batch_sequences)
|
||||
|
||||
self.temperature = params.temperature
|
||||
assert self.temperature > 0.
|
||||
|
||||
self.alpha_ce = params.alpha_ce
|
||||
self.alpha_mlm = params.alpha_mlm
|
||||
self.alpha_clm = params.alpha_clm
|
||||
self.alpha_mse = params.alpha_mse
|
||||
self.alpha_cos = params.alpha_cos
|
||||
assert self.alpha_ce >= 0.
|
||||
assert self.alpha_mlm >= 0.
|
||||
assert self.alpha_mse >= 0.
|
||||
assert self.alpha_cos >= 0.
|
||||
assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.
|
||||
|
||||
self.mlm_mask_prop = params.mlm_mask_prop
|
||||
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
||||
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
||||
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
|
||||
self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
|
||||
self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
|
||||
if self.fp16:
|
||||
self.pred_probs = self.pred_probs.half()
|
||||
self.token_probs = self.token_probs.half()
|
||||
self.mlm = params.mlm
|
||||
if self.mlm:
|
||||
logger.info(f'Using MLM loss for LM step.')
|
||||
self.mlm_mask_prop = params.mlm_mask_prop
|
||||
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
||||
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
||||
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
|
||||
self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
|
||||
self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
|
||||
if self.fp16:
|
||||
self.pred_probs = self.pred_probs.half()
|
||||
self.token_probs = self.token_probs.half()
|
||||
else:
|
||||
logger.info(f'Using CLM loss for LM step.')
|
||||
|
||||
self.epoch = 0
|
||||
self.n_iter = 0
|
||||
@ -86,12 +103,13 @@ class Distiller:
|
||||
self.last_loss = 0
|
||||
self.last_loss_ce = 0
|
||||
self.last_loss_mlm = 0
|
||||
self.last_loss_clm = 0
|
||||
if self.alpha_mse > 0.: self.last_loss_mse = 0
|
||||
if self.alpha_cos > 0.: self.last_loss_cos = 0
|
||||
self.last_log = 0
|
||||
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
|
||||
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
||||
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
||||
if self.alpha_mse > 0.:
|
||||
self.mse_loss_fct = nn.MSELoss(reduction='sum')
|
||||
if self.alpha_cos > 0.:
|
||||
@ -99,7 +117,7 @@ class Distiller:
|
||||
|
||||
logger.info('--- Initializing model optimizer')
|
||||
assert params.gradient_accumulation_steps >= 1
|
||||
self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
|
||||
self.num_steps_epoch = len(self.dataloader)
|
||||
num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
|
||||
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
@ -140,43 +158,18 @@ class Distiller:
|
||||
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
|
||||
self.student = DistributedDataParallel(self.student,
|
||||
device_ids=[params.local_rank],
|
||||
output_device=params.local_rank)
|
||||
output_device=params.local_rank,
|
||||
find_unused_parameters=True)
|
||||
|
||||
self.is_master = params.is_master
|
||||
if self.is_master:
|
||||
logger.info('--- Initializing Tensorboard')
|
||||
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
|
||||
self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
|
||||
self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
|
||||
self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)
|
||||
|
||||
def get_iterator(self,
|
||||
seed: int = None):
|
||||
"""
|
||||
Initialize the data iterator.
|
||||
Each process has its own data iterator (iterating on his own random portion of the dataset).
|
||||
|
||||
Input:
|
||||
------
|
||||
seed: `int` - The random seed.
|
||||
"""
|
||||
logger.info('--- Initializing Data Iterator')
|
||||
self.data_iterator = self.dataloader.get_iterator(seed=seed)
|
||||
|
||||
def get_batch(self):
|
||||
"""
|
||||
Call the data iterator to output a new batch.
|
||||
If the data iterator went through the whole dataset, create a new iterator.
|
||||
"""
|
||||
assert hasattr(self, 'data_iterator')
|
||||
try:
|
||||
x = next(self.data_iterator)
|
||||
except StopIteration:
|
||||
logger.warning('--- Went through the whole dataset. Creating new data iterator.')
|
||||
self.data_iterator = self.dataloader.get_iterator()
|
||||
x = next(self.data_iterator)
|
||||
return x
|
||||
|
||||
def prepare_batch(self,
|
||||
batch):
|
||||
def prepare_batch_mlm(self,
|
||||
batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM.
|
||||
|
||||
@ -222,7 +215,7 @@ class Distiller:
|
||||
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
|
||||
|
||||
_token_ids_real = token_ids[pred_mask]
|
||||
_token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
|
||||
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
|
||||
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
|
||||
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
|
||||
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
|
||||
@ -230,8 +223,41 @@ class Distiller:
|
||||
|
||||
mlm_labels[~pred_mask] = -1 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, mlm_labels
|
||||
|
||||
def prepare_batch_clm(self,
|
||||
batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the labels for CLM.
|
||||
|
||||
Input:
|
||||
------
|
||||
batch: `Tuple`
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
|
||||
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
|
||||
|
||||
Output:
|
||||
-------
|
||||
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
|
||||
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
|
||||
clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -1 where there is nothing to predict.
|
||||
"""
|
||||
token_ids, lengths = batch
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
assert token_ids.size(0) == lengths.size(0)
|
||||
|
||||
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
|
||||
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||
clm_labels[~attn_mask] = -1 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, clm_labels
|
||||
|
||||
def round_batch(self,
|
||||
x: torch.tensor,
|
||||
lengths: torch.tensor):
|
||||
@ -269,7 +295,10 @@ class Distiller:
|
||||
if ml1 % 8 != 0:
|
||||
pad = 8 - (ml1 % 8)
|
||||
ml2 = ml1 + pad
|
||||
pad_id = self.params.special_tok_ids['pad_token']
|
||||
if self.mlm:
|
||||
pad_id = self.params.special_tok_ids['pad_token']
|
||||
else:
|
||||
pad_id = self.params.special_tok_ids['unk_token']
|
||||
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
|
||||
x = torch.cat([x, padding_tensor], 1)
|
||||
assert x.size() == (bs2, ml2)
|
||||
@ -292,14 +321,16 @@ class Distiller:
|
||||
if self.multi_gpu:
|
||||
torch.distributed.barrier()
|
||||
|
||||
iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
|
||||
for __ in range(self.num_steps_epoch):
|
||||
batch = self.get_batch()
|
||||
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
|
||||
for batch in iter_bar:
|
||||
if self.params.n_gpu > 0:
|
||||
batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
|
||||
token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
|
||||
|
||||
self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)
|
||||
if self.mlm:
|
||||
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
|
||||
else:
|
||||
token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
|
||||
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
|
||||
|
||||
iter_bar.update()
|
||||
iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
|
||||
@ -317,7 +348,7 @@ class Distiller:
|
||||
def step(self,
|
||||
input_ids: torch.tensor,
|
||||
attention_mask: torch.tensor,
|
||||
mlm_labels: torch.tensor):
|
||||
lm_labels: torch.tensor):
|
||||
"""
|
||||
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
|
||||
and possibly a parameter update (depending on the gradient accumulation).
|
||||
@ -326,17 +357,22 @@ class Distiller:
|
||||
------
|
||||
input_ids: `torch.tensor(bs, seq_length)` - The token ids.
|
||||
attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
|
||||
mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
|
||||
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
||||
"""
|
||||
s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
||||
if self.mlm:
|
||||
s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
||||
else:
|
||||
s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||
assert s_logits.size() == t_logits.size()
|
||||
|
||||
#https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||
#https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
|
||||
if self.params.restrict_ce_to_mask:
|
||||
mask = (mlm_labels>-1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
else:
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
@ -348,13 +384,20 @@ class Distiller:
|
||||
loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
|
||||
loss = self.alpha_ce*loss_ce
|
||||
|
||||
if self.alpha_mlm > 0.:
|
||||
loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
|
||||
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
|
||||
loss += self.alpha_mlm * loss_mlm
|
||||
if self.alpha_clm > 0.:
|
||||
shift_logits = s_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1))
|
||||
loss += self.alpha_clm * loss_clm
|
||||
|
||||
if self.alpha_mse > 0.:
|
||||
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
|
||||
loss += self.alpha_mse * loss_mse
|
||||
|
||||
if self.alpha_cos > 0.:
|
||||
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
|
||||
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
|
||||
@ -376,6 +419,8 @@ class Distiller:
|
||||
self.last_loss_ce = loss_ce.item()
|
||||
if self.alpha_mlm > 0.:
|
||||
self.last_loss_mlm = loss_mlm.item()
|
||||
if self.alpha_clm > 0.:
|
||||
self.last_loss_clm = loss_clm.item()
|
||||
if self.alpha_mse > 0.:
|
||||
self.last_loss_mse = loss_mse.item()
|
||||
if self.alpha_cos > 0.:
|
||||
@ -452,6 +497,8 @@ class Distiller:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
|
||||
if self.alpha_mlm > 0.:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
|
||||
if self.alpha_clm > 0.:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
|
||||
if self.alpha_mse > 0.:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
|
||||
if self.alpha_cos > 0.:
|
||||
|
Loading…
Reference in New Issue
Block a user