Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
add dataset, distiller, utils
This commit is contained in:
parent 5d29f8e99b
commit 1ae81e4aa1
184  examples/distillation/dataset.py  Normal file
@@ -0,0 +1,184 @@
from typing import List
import math
from itertools import chain
from collections import Counter
import numpy as np
import torch

from utils import logger


class Dataset:
    def __init__(self,
                 params,
                 data):
        self.params = params
        self.tokens_per_batch = params.tokens_per_batch
        self.batch_size = params.batch_size
        self.shuffle = params.shuffle
        self.group_by_size = params.group_by_size

        self.token_ids = np.array(data)
        self.lengths = np.uint16([len(t) for t in data])

        self.check()
        self.remove_long_sequences()
        self.remove_empty_sequences()
        self.check()
        self.print_statistics()

    def __len__(self):
        return len(self.lengths)

    def check(self):
        """
        Some sanity checks.
        """
        assert len(self.token_ids) == len(self.lengths)

    def remove_long_sequences(self):
        """
        Sequences that are too long are split into chunks of at most max_position_embeddings tokens.
        """
        indices = self.lengths >= self.params.max_position_embeddings
        logger.info(f'Splitting {sum(indices)} too long sequences.')

        def divide_chunks(l, n):
            return [l[i:i + n] for i in range(0, len(l), n)]

        new_tok_ids = []
        new_lengths = []
        cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
        max_len = self.params.max_position_embeddings

        for seq_, len_ in zip(self.token_ids, self.lengths):
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
            else:
                sub_seqs = []
                # Split into chunks of max_len-2 tokens to leave room for the special tokens.
                for sub_s in divide_chunks(seq_, max_len - 2):
                    if sub_s[0] != cls_id:
                        sub_s = np.insert(sub_s, 0, cls_id)
                    if sub_s[-1] != sep_id:
                        # Close the chunk with the sep token.
                        sub_s = np.insert(sub_s, len(sub_s), sep_id)
                    assert len(sub_s) <= max_len
                    sub_seqs.append(sub_s)

                new_tok_ids.extend(sub_seqs)
                new_lengths.extend([len(l) for l in sub_seqs])

        self.token_ids = np.array(new_tok_ids)
        self.lengths = np.array(new_lengths)

    def remove_empty_sequences(self):
        """
        Too short sequences (<= 5 tokens) are simply removed. This could be tuned.
        """
        init_size = len(self)
        indices = self.lengths > 5
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f'Removed {init_size - new_size} too short (<=5 tokens) sequences.')

    def print_statistics(self):
        """
        Print some statistics on the corpus. Only the master process does this.
        """
        if not self.params.is_master:
            return
        logger.info(f'{len(self)} sequences')
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')

        # unk_idx = self.params.special_tok_ids['unk_token']
        # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')

    def select_data(self, a: int, b: int):
        """
        Select a sub-portion of the data.
        """
        n_sequences = len(self)
        assert 0 <= a < b <= n_sequences, f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}'

        logger.info(f'Selecting sequences from {a} to {b} (excluded).')
        self.token_ids = self.token_ids[a:b]
        self.lengths = self.lengths[a:b]

        self.check()

    def split(self):
        """
        Distributed training: split the data across the processes.
        """
        assert self.params.n_gpu > 1
        logger.info('Splitting the data across the processes.')
        n_seq = len(self)
        n_seq_per_process = n_seq // self.params.world_size
        a = n_seq_per_process * self.params.global_rank
        b = a + n_seq_per_process
        self.select_data(a=a, b=b)

    def batch_sequences(self,
                        token_ids: List[List[int]],
                        lengths: List[int]):
        """
        Do the padding and transform into torch.tensor.
        """
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        pad_idx = self.params.special_tok_ids['pad_token']
        tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)                  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths.astype(int))  # (bs)
        return tk_t, lg_t

    def get_batches_iterator(self,
                             batches):
        """
        Return an iterator over batches.
        """
        for sequences_ids in batches:
            token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids],
                                                      self.lengths[sequences_ids])
            yield (token_ids, lengths)

    def get_iterator(self,
                     seed: int = None):
        """
        Return a data iterator.
        """
        rng = np.random.RandomState(seed)

        n_sequences = len(self)
        indices = np.arange(n_sequences)

        if self.group_by_size:
            # Sort indices by sequence length so that each batch contains sequences of similar lengths.
            indices = indices[np.argsort(self.lengths[indices], kind='mergesort')]

        if self.tokens_per_batch == -1:
            batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
        else:
            # Token-budget batching: cut a new batch whenever the cumulative length crosses a multiple of tokens_per_batch.
            assert self.tokens_per_batch > 0
            batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch
            _, bounds = np.unique(batch_ids, return_index=True)
            batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
            if bounds[-1] < len(indices):
                batches.append(indices[bounds[-1]:])

        if self.shuffle:
            rng.shuffle(batches)

        assert n_sequences == sum([len(x) for x in batches])
        assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches])

        return self.get_batches_iterator(batches=batches)
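A minimal sketch of how this Dataset could be exercised end to end, not part of the commit: it builds a toy corpus and iterates with the token-budget batching above. The `params` attributes mirror exactly what the class reads; the special token ids and corpus are made up, and it assumes the 2019-era dependencies these files target (older NumPy accepts the ragged `np.array(data)` call in `__init__`) and that it is run from this directory.

# Illustrative only -- not part of the commit.
from types import SimpleNamespace
import numpy as np
from dataset import Dataset

params = SimpleNamespace(
    tokens_per_batch=64,        # token budget per batch (-1 would fall back to batch_size)
    batch_size=4,
    shuffle=True,
    group_by_size=True,
    max_position_embeddings=32,
    is_master=True,
    special_tok_ids={'cls_token': 101, 'sep_token': 102, 'pad_token': 0},  # hypothetical ids
)

# Toy corpus: sequences must be numpy arrays and longer than 5 tokens to survive the filters.
rng = np.random.RandomState(0)
data = [np.array([101] + rng.randint(1000, 2000, size=n).tolist() + [102], dtype=np.uint16)
        for n in rng.randint(6, 40, size=50)]

dataset = Dataset(params=params, data=data)
for token_ids, lengths in dataset.get_iterator(seed=0):
    # token_ids: (bs, max_len_in_batch) padded with pad_token; lengths: (bs)
    print(token_ids.shape, lengths.tolist())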
431  examples/distillation/distiller.py  Normal file
@@ -0,0 +1,431 @@
import os
import math
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_transformers import AdamW, WarmupLinearSchedule

from utils import logger
from dataset import Dataset


class Distiller:
    def __init__(self,
                 params: dict,
                 dataloader: Dataset,
                 token_probs: torch.Tensor,
                 student: nn.Module,
                 teacher: nn.Module):
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        assert params.word_mask + params.word_keep + params.word_rand == 1.0
        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
        self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
        self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_mse = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse_loss_fct = nn.MSELoss(reduction='sum')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay},
            {'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0}
        ]
        logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(self.student,
                                                          self.optimizer,
                                                          opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student,
                                                       device_ids=[params.local_rank],
                                                       output_device=params.local_rank)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)

    def get_iterator(self,
                     seed: int = None):
        """
        Initialize the data iterator.
        Each process has its own data iterator (iterating over its own random portion of the dataset).

        Input:
        ------
        seed: `int` - The random seed.
        """
        logger.info('--- Initializing Data Iterator')
        self.data_iterator = self.dataloader.get_iterator(seed=seed)

    def get_batch(self):
        """
        Call the data iterator to output a new batch.
        If the data iterator went through the whole dataset, create a new iterator.
        """
        assert hasattr(self, 'data_iterator')
        try:
            x = next(self.data_iterator)
        except StopIteration:
            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
            self.data_iterator = self.dataloader.get_iterator()
            x = next(self.data_iterator)
        return x

    def prepare_batch(self,
                      batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
        batch: `Tuple`
            token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequences. It is padded.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
        token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
        attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        # Sample the positions to predict, proportionally to token_probs.
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device)
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # Make the number of masked tokens a multiple of 8 (faster with fp16).
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        # For each selected position: mask it, keep it, or replace it with a random token (probabilities given by pred_probs).
        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[1-pred_mask] = -1

        return token_ids, attn_mask, mlm_labels

    def round_batch(self,
                    x: torch.Tensor,
                    lengths: torch.Tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
        x: `torch.tensor(bs, seq_length)` - The token ids.
        lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
        x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
        lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # Make the number of sentences a multiple of 8.
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # Make the sequence length a multiple of 8.
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            pad_id = self.params.special_tok_ids['pad_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master: logger.info('Starting training')
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')

            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)

                self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)

                iter_bar.update()
                iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
                                      'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
            iter_bar.close()

            if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master: logger.info('Training is finished')

    def step(self,
             input_ids: torch.Tensor,
             attention_mask: torch.Tensor,
             mlm_labels: torch.Tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
        """
        s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0]      # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0]  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)     # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)             # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))       # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)             # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))       # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                   F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce*loss_ce
        if self.alpha_mlm > 0.:
            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self,
                 loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error('NaN detected')
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.scheduler.step()
            self.optimizer.step()
            self.optimizer.zero_grad()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of an epoch (full pass over the dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self,
                        checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
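A minimal standalone sketch of the soft-target term computed in Distiller.step, not part of the commit: KL divergence between the temperature-softened student and teacher distributions, scaled by T**2 so that gradient magnitudes stay comparable across temperatures. The shapes and the temperature value are illustrative (16 selected positions over a 30522-token vocabulary).

# Illustrative only -- not part of the commit.
import torch
import torch.nn as nn
import torch.nn.functional as F

temperature = 2.0
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')

s_logits = torch.randn(16, 30522)   # student logits at the selected positions
t_logits = torch.randn(16, 30522)   # teacher logits at the same positions

loss_ce = ce_loss_fct(F.log_softmax(s_logits / temperature, dim=-1),
                      F.softmax(t_logits / temperature, dim=-1)) * temperature ** 2
print(loss_ce.item())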
112  examples/distillation/utils.py  Normal file
@@ -0,0 +1,112 @@
import git
import json
import os
import socket
import torch
import numpy as np

import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def git_log(folder_path: str):
    """
    Log commit info.
    """
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        'repo_id': str(repo),
        'repo_sha': str(repo.head.object.hexsha),
        'repo_branch': str(repo.active_branch)
    }

    with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
        json.dump(repo_infos, f, indent=4)


def init_gpu_params(params):
    """
    Handle single and multi-GPU / multi-node.
    """
    if params.n_gpu <= 0:
        params.local_rank = 0
        params.master_port = -1
        params.is_master = True
        params.multi_gpu = False
        return

    assert torch.cuda.is_available()

    logger.info('Initializing GPUs')
    if params.n_gpu > 1:
        assert params.local_rank != -1

        params.world_size = int(os.environ['WORLD_SIZE'])
        params.n_gpu_per_node = int(os.environ['N_GPU_NODE'])
        params.global_rank = int(os.environ['RANK'])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node
        params.multi_gpu = True

        assert params.n_nodes == int(os.environ['N_NODES'])
        assert params.node_id == int(os.environ['NODE_RANK'])

    # local job (single GPU)
    else:
        assert params.local_rank == -1

        params.n_nodes = 1
        params.node_id = 0
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        params.n_gpu_per_node = 1
        params.multi_gpu = False

    # sanity checks
    assert params.n_nodes >= 1
    assert 0 <= params.node_id < params.n_nodes
    assert 0 <= params.local_rank <= params.global_rank < params.world_size
    assert params.world_size == params.n_nodes * params.n_gpu_per_node

    # define whether this is the master process / if we are in multi-node distributed mode
    params.is_master = params.node_id == 0 and params.local_rank == 0
    params.multi_node = params.n_nodes > 1

    # summary
    PREFIX = f"--- Global rank: {params.global_rank} - "
    logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
    logger.info(PREFIX + "Node ID        : %i" % params.node_id)
    logger.info(PREFIX + "Local rank     : %i" % params.local_rank)
    logger.info(PREFIX + "World size     : %i" % params.world_size)
    logger.info(PREFIX + "GPUs per node  : %i" % params.n_gpu_per_node)
    logger.info(PREFIX + "Master         : %s" % str(params.is_master))
    logger.info(PREFIX + "Multi-node     : %s" % str(params.multi_node))
    logger.info(PREFIX + "Multi-GPU      : %s" % str(params.multi_gpu))
    logger.info(PREFIX + "Hostname       : %s" % socket.gethostname())

    # set GPU device
    torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if params.multi_gpu:
        logger.info("Initializing PyTorch distributed")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
        )


def set_seed(args):
    """
    Set the random seed.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
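A minimal sketch of how a training script might wire these helpers together, not part of the commit: the argument names simply mirror the attributes that init_gpu_params and set_seed read, and the actual entry-point script is not included in this diff.

# Illustrative only -- not part of the commit.
import argparse
from utils import git_log, init_gpu_params, logger, set_seed

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dump_path', type=str, default='serialization_dir')  # hypothetical
    parser.add_argument('--n_gpu', type=int, default=0)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=56)
    params = parser.parse_args()

    init_gpu_params(params)   # sets is_master, world_size, ... and init_process_group if needed
    set_seed(params)
    if params.is_master:
        git_log(params.dump_path)   # needs dump_path to exist and the script to run inside a git repo
    logger.info(f'Master process: {params.is_master}')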