Mirror of https://github.com/huggingface/transformers.git
Update DistilBERT training code
This commit is contained in:
parent f9453d15e5
commit dddd6b9927

@@ -9,6 +9,12 @@ DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and l

For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5).

## Setup

This part of the library has only been tested with Python 3.6+. There are a few specific dependencies to install before launching a distillation; you can install them with the command `pip install -r requirements.txt`.

**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0). Note that there is a small internal bug in the version of PyTorch currently available on pip that causes a memory leak in our training/distillation. It has recently been fixed and will likely be included in the next release. For the moment, we recommend [compiling PyTorch from source](https://github.com/pytorch/pytorch#from-source). Please refer to [issue 1179](https://github.com/huggingface/pytorch-transformers/issues/1179) for more details.
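
As a quick sanity check before launching a distillation run, the snippet below (illustrative only, not part of the repository; it assumes the `packaging` package is available) verifies that the installed PyTorch version is recent enough for the updated scripts:

```python
import torch
from packaging import version  # assumption: `packaging` is installed (it usually ships with pip/setuptools)

# The training scripts target PyTorch >= 1.2.0 (breaking changes vs. 1.1.0).
if version.parse(torch.__version__) < version.parse("1.2.0"):
    raise RuntimeError(f"PyTorch >= 1.2.0 is required, found {torch.__version__}")
```
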
## How to use DistilBERT

PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility of training and releasing a multilingual version of DistilBERT):
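
The list of checkpoints is truncated by the hunk boundary above. As an illustration only (the class names, import path, and checkpoint name are assumptions based on the library's naming conventions, not shown in this diff), loading one of the English models might look like:

```python
import torch
from pytorch_transformers import DistilBertModel, DistilBertTokenizer  # assumed import path

# Assumed checkpoint name; see the repository README for the released models.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Hello, DistilBERT!")])
with torch.no_grad():
    last_hidden_state = model(input_ids)[0]
print(last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```
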
@@ -17,6 +17,7 @@
"""
import os
import math
import psutil
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
import numpy as np

@@ -192,7 +193,7 @@ class Distiller:
x_prob = self.token_probs[token_ids.flatten()]
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
-pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device)
+pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len)
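
A minimal, self-contained sketch (not taken from the repository; shapes and positions are made up) of why the mask is now built as `torch.bool`: on PyTorch >= 1.2.0, indexing with `uint8` masks is deprecated in favour of boolean masks.

```python
import torch

bs, max_seq_len = 2, 4
token_ids = torch.randint(0, 100, (bs, max_seq_len))

# Selection mask built directly as torch.bool (the old uint8 variant emits a
# deprecation warning on PyTorch >= 1.2.0 when used for indexing).
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool)
pred_mask[torch.tensor([1, 5])] = True   # hypothetical target positions
pred_mask = pred_mask.view(bs, max_seq_len)

masked_tokens = token_ids[pred_mask]     # boolean mask indexing
print(masked_tokens.shape)               # torch.Size([2])
```
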
@@ -216,7 +217,7 @@ class Distiller:
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

-mlm_labels[1-pred_mask] = -1
+mlm_labels[~pred_mask] = -1 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

return token_ids, attn_mask, mlm_labels
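
The complementary-mask change can be illustrated in isolation (toy tensors, not repository code): with a `torch.bool` mask, the positions that are *not* selected are obtained with `~pred_mask`, since `1 - mask` is no longer supported on boolean tensors.

```python
import torch

mlm_labels = torch.tensor([[7, 8, 9]])
pred_mask = torch.tensor([[True, False, True]])

labels = mlm_labels.clone()
labels[~pred_mask] = -1   # positions that are not masked are ignored by the MLM loss
print(labels)             # tensor([[ 7, -1,  9]])
```
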
@@ -379,9 +380,9 @@ class Distiller:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
-self.scheduler.step()
self.optimizer.step()
self.optimizer.zero_grad()
+self.scheduler.step()

def iter(self):
"""
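
The reordering follows the convention introduced in PyTorch 1.1: `optimizer.step()` should be called before `scheduler.step()`, otherwise the first value of the learning-rate schedule is skipped and PyTorch emits a warning. A minimal training-loop sketch with a toy model (names are illustrative, not from the repository):

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for step in range(3):
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    optimizer.step()        # update the weights first...
    optimizer.zero_grad()
    scheduler.step()        # ...then advance the learning-rate schedule
```
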
@@ -418,6 +419,8 @@ class Distiller:
if self.alpha_mse > 0.:
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)

self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)

def end_epoch(self):
"""
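
The memory-usage scalar relies on `psutil`. A stripped-down sketch of the same logging pattern (the log directory name is an assumption, not from the repository):

```python
import psutil
from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="runs/distillation")  # assumed log directory
for step in range(10):
    used_mb = psutil.virtual_memory()._asdict()["used"] / 1_000_000
    writer.add_scalar(tag="global/memory_usage", scalar_value=used_mb, global_step=step)
writer.close()
```
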
@@ -1 +1,4 @@
gitpython==3.0.2
tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.3