Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
add dataset, distiller, utils
This commit is contained in:
parent 5d29f8e99b
commit 1ae81e4aa1
184
examples/distillation/dataset.py
Normal file
@@ -0,0 +1,184 @@
from typing import List
import math
from itertools import chain
from collections import Counter
import numpy as np
import torch

from utils import logger


class Dataset:
    def __init__(self,
                 params,
                 data):
        self.params = params
        self.tokens_per_batch = params.tokens_per_batch
        self.batch_size = params.batch_size
        self.shuffle = params.shuffle
        self.group_by_size = params.group_by_size

        self.token_ids = np.array(data)
        self.lengths = np.uint16([len(t) for t in data])

        self.check()
        self.remove_long_sequences()
        self.remove_empty_sequences()
        self.check()
        self.print_statistics()

    def __len__(self):
        return len(self.lengths)

    def check(self):
        """
        Some sanity checks.
        """
        assert len(self.token_ids) == len(self.lengths)

    def remove_long_sequences(self):
        """
        Sequences that are too long are split into chunks of max_position_embeddings tokens.
        """
        indices = self.lengths >= self.params.max_position_embeddings
        logger.info(f'Splitting {sum(indices)} too long sequences.')

        def divide_chunks(l, n):
            return [l[i:i + n] for i in range(0, len(l), n)]

        new_tok_ids = []
        new_lengths = []
        cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
        max_len = self.params.max_position_embeddings

        for seq_, len_ in zip(self.token_ids, self.lengths):
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
            else:
                sub_seqs = []
                for sub_s in divide_chunks(seq_, max_len - 2):
                    if sub_s[0] != cls_id:
                        sub_s = np.insert(sub_s, 0, cls_id)
                    if sub_s[-1] != sep_id:
                        sub_s = np.insert(sub_s, len(sub_s), sep_id)
                    assert len(sub_s) <= max_len
                    sub_seqs.append(sub_s)

                new_tok_ids.extend(sub_seqs)
                new_lengths.extend([len(l) for l in sub_seqs])

        self.token_ids = np.array(new_tok_ids)
        self.lengths = np.array(new_lengths)

    def remove_empty_sequences(self):
        """
        Too short sequences are simply removed. This could be tuned.
        """
        init_size = len(self)
        indices = self.lengths > 5
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f'Removed {init_size - new_size} too short (<=5 tokens) sequences.')

    def print_statistics(self):
        """
        Print some statistics on the corpus (only the master process).
        """
        if not self.params.is_master:
            return
        logger.info(f'{len(self)} sequences')
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')

        # unk_idx = self.params.special_tok_ids['unk_token']
        # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')

    def select_data(self, a: int, b: int):
        """
        Select a sub-portion of the data.
        """
        n_sequences = len(self)
        assert 0 <= a < b <= n_sequences, ValueError(f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}')

        logger.info(f'Selecting sequences from {a} to {b} (excluded).')
        self.token_ids = self.token_ids[a:b]
        self.lengths = self.lengths[a:b]

        self.check()

    def split(self):
        """
        Distributed training: split the data across the processes.
        """
        assert self.params.n_gpu > 1
        logger.info('Splitting the data across the processes.')
        n_seq = len(self)
        n_seq_per_process = n_seq // self.params.world_size
        a = n_seq_per_process * self.params.global_rank
        b = a + n_seq_per_process
        self.select_data(a=a, b=b)

    def batch_sequences(self,
                        token_ids: List[List[int]],
                        lengths: List[int]):
        """
        Do the padding and transform into torch.tensor.
        """
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        pad_idx = self.params.special_tok_ids['pad_token']
        tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)                  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths.astype(int))  # (bs)
        return tk_t, lg_t

    def get_batches_iterator(self,
                             batches):
        """
        Return an iterator over batches.
        """
        for sequences_ids in batches:
            token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids],
                                                      self.lengths[sequences_ids])
            yield (token_ids, lengths)

    def get_iterator(self,
                     seed: int = None):
        """
        Return a data iterator.
        """
        rng = np.random.RandomState(seed)

        n_sequences = len(self)
        indices = np.arange(n_sequences)

        if self.group_by_size:
            indices = indices[np.argsort(self.lengths[indices], kind='mergesort')]

        if self.tokens_per_batch == -1:
            batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
        else:
            assert self.tokens_per_batch > 0
            batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch
            _, bounds = np.unique(batch_ids, return_index=True)
            batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
            if bounds[-1] < len(indices):
                batches.append(indices[bounds[-1]:])

        if self.shuffle:
            rng.shuffle(batches)

        assert n_sequences == sum([len(x) for x in batches])
        assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches])

        return self.get_batches_iterator(batches=batches)
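
A minimal usage sketch for the Dataset class above (illustration only, not part of the commit; the fields on `params` mirror the attributes read in `__init__`, and the token ids are arbitrary placeholder values):

    from types import SimpleNamespace
    import numpy as np
    from dataset import Dataset

    # Hypothetical parameters; only the fields accessed by Dataset are filled in.
    params = SimpleNamespace(
        tokens_per_batch=-1,       # -1: split into batches of `batch_size` sequences
        batch_size=2,
        shuffle=True,
        group_by_size=True,
        max_position_embeddings=512,
        special_tok_ids={'cls_token': 101, 'sep_token': 102, 'pad_token': 0},
        is_master=True,
    )
    # Illustrative pre-tokenized corpus (ids are arbitrary; sequences must be > 5 tokens long).
    data = [np.array([101, 11, 12, 13, 14, 15, 102]),
            np.array([101, 21, 22, 23, 24, 25, 102])]

    dataset = Dataset(params=params, data=data)
    for token_ids, lengths in dataset.get_iterator(seed=0):
        pass  # token_ids: (bs, max_len) padded LongTensor, lengths: (bs,)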
431
examples/distillation/distiller.py
Normal file
@@ -0,0 +1,431 @@
import os
import math
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_transformers import AdamW, WarmupLinearSchedule

from utils import logger
from dataset import Dataset


class Distiller:
    def __init__(self,
                 params: dict,
                 dataloader: Dataset,
                 token_probs: torch.tensor,
                 student: nn.Module,
                 teacher: nn.Module):
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        assert params.word_mask + params.word_keep + params.word_rand == 1.0
        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
        self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
        self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_mse = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse_loss_fct = nn.MSELoss(reduction='sum')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay},
            {'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0}
        ]
        logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(self.student,
                                                          self.optimizer,
                                                          opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student,
                                                       device_ids=[params.local_rank],
                                                       output_device=params.local_rank)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)

    def get_iterator(self,
                     seed: int = None):
        """
        Initialize the data iterator.
        Each process has its own data iterator (iterating on its own random portion of the dataset).

        Input:
        ------
        seed: `int` - The random seed.
        """
        logger.info('--- Initializing Data Iterator')
        self.data_iterator = self.dataloader.get_iterator(seed=seed)

    def get_batch(self):
        """
        Call the data iterator to output a new batch.
        If the data iterator went through the whole dataset, create a new iterator.
        """
        assert hasattr(self, 'data_iterator')
        try:
            x = next(self.data_iterator)
        except StopIteration:
            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
            self.data_iterator = self.dataloader.get_iterator()
            x = next(self.data_iterator)
        return x

    def prepare_batch(self,
                      batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
        batch: `Tuple`
            token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequences (padded).
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
        token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
        attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device)
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # mask a number of words == 0 [8] (i.e. a multiple of 8, which is faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[1-pred_mask] = -1

        return token_ids, attn_mask, mlm_labels
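
    # Note on the masking performed above: roughly `mlm_mask_prop` of the real tokens are selected
    # for prediction (sampled over the flattened batch with weights `token_probs`, e.g. smoothed
    # corpus frequencies passed in at construction; padding positions are then discarded). Each
    # selected token is replaced by the mask id, left unchanged, or replaced by a random id with
    # probabilities (word_mask, word_keep, word_rand) -- the usual BERT recipe is 0.8/0.1/0.1.
    # `mlm_labels` keeps the original ids at selected positions and holds -1 elsewhere, which the
    # MLM cross-entropy used in step() ignores (ignore_index=-1).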

    def round_batch(self,
                    x: torch.tensor,
                    lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
        x: `torch.tensor(bs, seq_length)` - The token ids.
        lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
        x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
        lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8] (multiple of 8)
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # sequence length == 0 [8] (multiple of 8)
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            pad_id = self.params.special_tok_ids['pad_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths
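
    # Rounding both the batch size and the sequence length to multiples of 8 lets the fp16 matrix
    # multiplications map onto Tensor Cores, which is why round_batch() only does anything when
    # fp16 training is enabled: a few sentences may be dropped and a little extra padding added
    # in exchange for faster kernels.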

    def train(self):
        """
        The real training loop.
        """
        if self.is_master: logger.info('Starting training')
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')

            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)

                self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)

                iter_bar.update()
                iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
                                      'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
            iter_bar.close()

            if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master: logger.info('Training is finished')

    def step(self,
             input_ids: torch.tensor,
             attention_mask: torch.tensor,
             mlm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
        """
        s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0]      # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0]  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)     # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)             # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))       # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)             # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))       # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                   F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce*loss_ce
        if self.alpha_mlm > 0.:
            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self,
                 loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error('NaN detected')
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.scheduler.step()
            self.optimizer.step()
            self.optimizer.zero_grad()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self,
                        checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
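
For reference, a self-contained illustration of the objective computed in step() above, run on random logits (illustration only, not part of the commit; the alpha weights, temperature, and tensor shapes are placeholder values):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    temperature, alpha_ce, alpha_mse = 2.0, 1.0, 0.5   # illustrative values only
    s_logits_slct = torch.randn(16, 30522)              # student logits at the kept positions
    t_logits_slct = torch.randn(16, 30522)              # teacher logits at the same positions

    # Soft-target loss: KL divergence between temperature-softened teacher and student
    # distributions, rescaled by T^2 so its gradient magnitude is comparable to the hard loss.
    loss_ce = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(s_logits_slct / temperature, dim=-1),
                                                  F.softmax(t_logits_slct / temperature, dim=-1)) * temperature ** 2
    # Optional MSE between the same selected logits, normalized to reproduce a batchmean reduction.
    loss_mse = nn.MSELoss(reduction='sum')(s_logits_slct, t_logits_slct) / s_logits_slct.size(0)
    loss = alpha_ce * loss_ce + alpha_mse * loss_mse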
112
examples/distillation/utils.py
Normal file
@@ -0,0 +1,112 @@
import git
import json
import os
import socket
import torch
import numpy as np

import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def git_log(folder_path: str):
    """
    Log commit info.
    """
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        'repo_id': str(repo),
        'repo_sha': str(repo.head.object.hexsha),
        'repo_branch': str(repo.active_branch)
    }

    with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
        json.dump(repo_infos, f, indent=4)


def init_gpu_params(params):
    """
    Handle single and multi-GPU / multi-node setups.
    """
    if params.n_gpu <= 0:
        params.local_rank = 0
        params.master_port = -1
        params.is_master = True
        params.multi_gpu = False
        return

    assert torch.cuda.is_available()

    logger.info('Initializing GPUs')
    if params.n_gpu > 1:
        assert params.local_rank != -1

        params.world_size = int(os.environ['WORLD_SIZE'])
        params.n_gpu_per_node = int(os.environ['N_GPU_NODE'])
        params.global_rank = int(os.environ['RANK'])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node
        params.multi_gpu = True

        assert params.n_nodes == int(os.environ['N_NODES'])
        assert params.node_id == int(os.environ['NODE_RANK'])

    # local job (single GPU)
    else:
        assert params.local_rank == -1

        params.n_nodes = 1
        params.node_id = 0
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        params.n_gpu_per_node = 1
        params.multi_gpu = False

    # sanity checks
    assert params.n_nodes >= 1
    assert 0 <= params.node_id < params.n_nodes
    assert 0 <= params.local_rank <= params.global_rank < params.world_size
    assert params.world_size == params.n_nodes * params.n_gpu_per_node

    # define whether this is the master process / if we are in multi-node distributed mode
    params.is_master = params.node_id == 0 and params.local_rank == 0
    params.multi_node = params.n_nodes > 1

    # summary
    PREFIX = f"--- Global rank: {params.global_rank} - "
    logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
    logger.info(PREFIX + "Node ID        : %i" % params.node_id)
    logger.info(PREFIX + "Local rank     : %i" % params.local_rank)
    logger.info(PREFIX + "World size     : %i" % params.world_size)
    logger.info(PREFIX + "GPUs per node  : %i" % params.n_gpu_per_node)
    logger.info(PREFIX + "Master         : %s" % str(params.is_master))
    logger.info(PREFIX + "Multi-node     : %s" % str(params.multi_node))
    logger.info(PREFIX + "Multi-GPU      : %s" % str(params.multi_gpu))
    logger.info(PREFIX + "Hostname       : %s" % socket.gethostname())

    # set GPU device
    torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if params.multi_gpu:
        logger.info("Initializing PyTorch distributed")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
        )


def set_seed(args):
    """
    Set the random seed.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
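
A rough sketch of how a training script would typically call these helpers (illustration only, not part of the commit; the argument names mirror the fields read above, and the default values are illustrative):

    import argparse
    import os
    from utils import git_log, init_gpu_params, set_seed, logger

    parser = argparse.ArgumentParser()
    parser.add_argument('--dump_path', default='serialization_dir')
    parser.add_argument('--n_gpu', type=int, default=0)        # 0: CPU, 1: single GPU, >1: distributed
    parser.add_argument('--local_rank', type=int, default=-1)  # set by torch.distributed.launch
    parser.add_argument('--seed', type=int, default=56)
    params = parser.parse_args()

    init_gpu_params(params)   # fills is_master, multi_gpu, world_size, ... and sets the CUDA device
    set_seed(params)
    if params.is_master:
        os.makedirs(params.dump_path, exist_ok=True)
        git_log(params.dump_path)  # writes git_log.json with the current commit sha and branch
        logger.info(f'Experiment will be dumped in {params.dump_path}')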