diff --git a/examples/distillation/README.md b/examples/distillation/README.md
index bb919385f17..73e0cc06559 100644
--- a/examples/distillation/README.md
+++ b/examples/distillation/README.md
@@ -9,6 +9,12 @@ DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and l
 For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
 ).
 
+## Setup
+
+This part of the library has only been tested with Python 3.6+. There are a few specific dependencies to install before launching a distillation; you can install them with the command `pip install -r requirements.txt`.
+
+**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0). Note that there is a small internal bug in the version of PyTorch currently available on pip that causes a memory leak in our training/distillation. It has recently been fixed and will likely be included in the next release. For the moment, we recommend [compiling PyTorch from source](https://github.com/pytorch/pytorch#from-source). Please refer to [issue 1179](https://github.com/huggingface/pytorch-transformers/issues/1179) for more details.
+
 ## How to use DistilBERT
 
 PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT):
diff --git a/examples/distillation/distiller.py b/examples/distillation/distiller.py
index 38769c4b0ec..ed710a2bee6 100644
--- a/examples/distillation/distiller.py
+++ b/examples/distillation/distiller.py
@@ -17,6 +17,7 @@
 """
 import os
 import math
+import psutil
 from tensorboardX import SummaryWriter
 from tqdm import trange, tqdm
 import numpy as np
@@ -192,7 +193,7 @@ class Distiller:
         x_prob = self.token_probs[token_ids.flatten()]
         n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
         tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
-        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device)
+        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device)  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
         pred_mask[tgt_ids] = 1
         pred_mask = pred_mask.view(bs, max_seq_len)
@@ -216,7 +217,7 @@ class Distiller:
         _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
         token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
 
-        mlm_labels[1-pred_mask] = -1
+        mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
 
         return token_ids, attn_mask, mlm_labels
@@ -379,9 +380,9 @@ class Distiller:
             torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
         else:
             torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
-        self.scheduler.step()
         self.optimizer.step()
         self.optimizer.zero_grad()
+        self.scheduler.step()
 
     def iter(self):
         """
@@ -418,6 +419,8 @@ class Distiller:
             if self.alpha_mse > 0.: self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
             self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
+
+            self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
 
     def end_epoch(self):
         """
diff --git a/examples/distillation/requirements.txt b/examples/distillation/requirements.txt
index efb369dc438..18146239eb4 100644
--- a/examples/distillation/requirements.txt
+++ b/examples/distillation/requirements.txt
@@ -1 +1,4 @@
 gitpython==3.0.2
+tensorboard>=1.14.0
+tensorboardX==1.8
+psutil==5.6.3
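
For readers applying the same migration in their own code, the snippet below is a minimal, self-contained sketch of the PyTorch 1.2.0 idioms that the `distiller.py` changes above rely on: boolean masks instead of `uint8`, `~mask` instead of `1 - mask`, and stepping the LR scheduler after the optimizer. The tensors, model, and hyperparameters here are made up for illustration and are not part of the patch.

```python
# Illustrative sketch only (not part of the patch); shapes and values are dummies.
import torch

token_ids = torch.randint(0, 100, (2, 8))   # dummy batch: bs=2, max_seq_len=8
mlm_labels = token_ids.clone()

# Masks used for indexing should be boolean; uint8 masks are deprecated in PyTorch 1.2.0.
pred_mask = torch.zeros(token_ids.shape, dtype=torch.bool)
pred_mask[:, :2] = True

# `1 - mask` is no longer valid on bool tensors; use logical negation instead.
mlm_labels[~pred_mask] = -1

# Since PyTorch 1.1.0, `scheduler.step()` should be called after `optimizer.step()`.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
```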