diff --git a/examples/research_projects/zero-shot-distillation/README.md b/examples/research_projects/zero-shot-distillation/README.md
new file mode 100644
index 00000000000..cf20cb40bcd
--- /dev/null
+++ b/examples/research_projects/zero-shot-distillation/README.md
@@ -0,0 +1,155 @@
# Zero-shot classifier distillation

Author: @joeddav

This script provides a way to improve the speed and memory performance of a zero-shot classifier by training a more
efficient student model from the zero-shot teacher's predictions over an unlabeled dataset.

The zero-shot classification pipeline uses a model pre-trained on natural language inference (NLI) to determine the
compatibility of a set of candidate class names with a given sequence. This serves as a convenient out-of-the-box
classifier without the need for labeled training data. However, for a given sequence, the method requires each
possible label to be fed through the large NLI model separately. Thus for `N` sequences and `K` classes, a total of
`N*K` forward passes through the model are required. This requirement slows inference considerably, particularly as
`K` grows.

Given (1) an unlabeled corpus and (2) a set of candidate class names, the provided script trains a student model
with a standard classification head with `K` output dimensions. The resulting student model can then be used for
classifying novel text instances with a significant boost in speed and memory performance while retaining similar
classification performance to the original zero-shot model.

### Usage

A teacher NLI model can be distilled to a more efficient student model by running `distill_classifier.py`:

```
python distill_classifier.py \
--data_file <unlabeled data file> \
--class_names_file <class names file> \
--output_dir <output dir>
```

`<unlabeled data file>` should be a text file with a single unlabeled example per line. `<class names file>` is a
text file with one class name per line.

Other optional arguments include:

- `--teacher_name_or_path` (default: `roberta-large-mnli`): The name or path of the NLI teacher model.
- `--student_name_or_path` (default: `distilbert-base-uncased`): The name or path of the student model which will
be fine-tuned to copy the teacher predictions.
- `--hypothesis_template` (default: `"This example is {}."`): The template used to turn each label into an NLI-style
hypothesis when generating teacher predictions. This template must include a `{}` or similar syntax for the
candidate label to be inserted into the template. For example, the default template is `"This example is {}."` With
the candidate label `sports`, this would be fed into the model like `[CLS] sequence to classify [SEP] This example
is sports . [SEP]`.
- `--multi_class`: Whether or not multiple candidate labels can be true. By default, the scores are normalized such
that the sum of the label likelihoods for each sequence is 1. If `--multi_class` is passed, the labels are
considered independent and probabilities are normalized for each candidate by doing a softmax of the entailment
score vs. the contradiction score, as sketched after this list. This is sometimes called "multi-class multi-label"
classification.
- `--temperature` (default: `1.0`): The temperature applied to the softmax of the teacher model predictions. A
higher temperature results in a student with smoother (lower confidence) predictions than the teacher, while a value
`<1` results in a higher-confidence, peaked distribution. The default `1.0` is equivalent to no smoothing.
- `--teacher_batch_size` (default: `32`): The batch size used for generating a single set of teacher predictions.
Does not affect training. Use `--per_device_train_batch_size` to change the training batch size.
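For reference, the way `--multi_class` and `--temperature` affect the teacher scores mirrors the normalization done
in `get_teacher_predictions` inside `distill_classifier.py`. A minimal sketch (the logits tensor here is random and
purely illustrative):

```python
import torch

# Dummy teacher outputs: (num_examples, num_classes, 2) logits, where the last
# dimension holds the [contradiction, entailment] scores for each candidate label.
nli_logits = torch.randn(8, 4, 2)
temperature = 1.0

# Default: softmax over the class dimension, so the entailment probabilities
# for each example sum to 1 across the candidate labels.
single_label_probs = (nli_logits / temperature).softmax(dim=1)[..., 1]

# With --multi_class: softmax over (contradiction, entailment) independently
# for each label, so each probability lies in [0, 1] on its own.
multi_class_probs = (nli_logits / temperature).softmax(dim=-1)[..., 1]
```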
Any of the arguments in the 🤗 Trainer's
[`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html?#trainingarguments) can also be
modified, such as `--learning_rate`, `--fp16`, `--no_cuda`, `--warmup_steps`, etc. Run `python distill_classifier.py
-h` for a full list of available arguments or consult the [Trainer
documentation](https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments).

> **Note**: Distributed and TPU training are not currently supported. Single-node multi-GPU is supported, however,
and will run automatically if multiple GPUs are available.

### Example: Topic classification

> A full colab demo notebook of this example can be found [here](https://colab.research.google.com/drive/1mjBjd0cR8G57ZpsnFCS3ngGyo5nCa9ya?usp=sharing).

Let's say we're interested in classifying news articles into one of four topic categories: "the world", "sports",
"business", or "science/tech". We have an unlabeled dataset, [AG's News](https://huggingface.co/datasets/ag_news),
which corresponds to this problem (in reality AG's News is annotated, but we will pretend it is not for the sake of
example).

We can use an NLI model like `roberta-large-mnli` for zero-shot classification like so:

```python
>>> from transformers import pipeline

>>> class_names = ["the world", "sports", "business", "science/tech"]
>>> hypothesis_template = "This text is about {}."
>>> sequence = "A new moon has been discovered in Jupiter's orbit"

>>> zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
>>> zero_shot_classifier(sequence, class_names, hypothesis_template=hypothesis_template)
{'sequence': "A new moon has been discovered in Jupiter's orbit",
 'labels': ['science/tech', 'the world', 'business', 'sports'],
 'scores': [0.7035840153694153, 0.18744826316833496, 0.06027870625257492, 0.04868902638554573]}
```

Unfortunately, inference is slow since each of our 4 class names must be fed through the large model for every
sequence to be classified. But with our unlabeled data we can distill the model to a small distilbert classifier to
make future inference much faster.

To run the script, we will need to put each training example (text only) from AG's News on its own line in
`agnews/unlabeled.txt`, and each of the four class names in the newline-separated `agnews/class_names.txt` (one way
to produce these files is sketched below). Then we can run distillation with the following command:

```bash
python distill_classifier.py \
--data_file ./agnews/unlabeled.txt \
--class_names_file ./agnews/class_names.txt \
--teacher_name_or_path roberta-large-mnli \
--hypothesis_template "This text is about {}." \
--output_dir ./agnews/distilled
```

The script will generate a set of soft zero-shot predictions from `roberta-large-mnli` for each example in
`agnews/unlabeled.txt`. It will then train a student distilbert classifier on the teacher predictions and
save the resulting model in `./agnews/distilled`.
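For completeness, here is one way the two input files could be prepared with the 🤗 `datasets` library. This is a
minimal sketch, assuming the `ag_news` dataset from the Hub and the file layout used in the command above:

```python
import os

from datasets import load_dataset

os.makedirs("agnews", exist_ok=True)
train = load_dataset("ag_news", split="train")

# One unlabeled example per line (labels are ignored; newlines inside a text are flattened).
with open("agnews/unlabeled.txt", "w") as f:
    for text in train["text"]:
        f.write(text.replace("\n", " ") + "\n")

# One class name per line; the order determines the student's output dimensions.
with open("agnews/class_names.txt", "w") as f:
    f.write("\n".join(["the world", "sports", "business", "science/tech"]) + "\n")
```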
The resulting model can then be loaded and used like any other pre-trained classifier:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline

model = AutoModelForSequenceClassification.from_pretrained("./agnews/distilled")
tokenizer = AutoTokenizer.from_pretrained("./agnews/distilled")
```

and even used trivially with a `TextClassificationPipeline`:

```python
>>> distilled_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
>>> distilled_classifier(sequence)
[[{'label': 'the world', 'score': 0.14899294078350067},
  {'label': 'sports', 'score': 0.03205857425928116},
  {'label': 'business', 'score': 0.05943061783909798},
  {'label': 'science/tech', 'score': 0.7595179080963135}]]
```

> Tip: pass `device=0` when constructing a pipeline to run on a GPU.

As we can see, the results of the student closely resemble those of the teacher despite the student never having
seen this example during training. Now let's do a quick & dirty speed comparison simulating 16K examples with a
batch size of 16:

```python
%%time
for _ in range(1000):
    zero_shot_classifier([sequence] * 16, class_names)
# runs in 1m 23s on a single V100 GPU
```

```python
%%time
for _ in range(1000):
    distilled_classifier([sequence] * 16)
# runs in 10.3s on a single V100 GPU
```

As we can see, the distilled student model runs an order of magnitude faster than its teacher NLI model. This is
also a setting where we only have `K=4` possible labels. The higher the number of classes for a given task, the more
drastic the speedup will be, since the zero-shot teacher's complexity scales linearly with the number of classes.

Since we secretly have access to ground truth labels for AG's News, we can evaluate the accuracy of each model. The
original zero-shot model `roberta-large-mnli` gets an accuracy of 69.3% on the held-out test set. After training a
student on the unlabeled training set, the distilled model gets a similar score of 70.4%.

Lastly, you can share the distilled model with the community and/or use it with our inference API by [uploading it
to the 🤗 Hub](https://huggingface.co/transformers/model_sharing.html). We've uploaded the distilled model from this
example at
[joeddav/distilbert-base-uncased-agnews-student](https://huggingface.co/joeddav/distilbert-base-uncased-agnews-student).

diff --git a/examples/research_projects/zero-shot-distillation/distill_classifier.py b/examples/research_projects/zero-shot-distillation/distill_classifier.py
new file mode 100644
index 00000000000..f1603876190
--- /dev/null
+++ b/examples/research_projects/zero-shot-distillation/distill_classifier.py
@@ -0,0 +1,338 @@
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import List, Optional

import torch
from datasets import Dataset
from torch import nn
from tqdm.auto import tqdm

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    utils,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process


DESCRIPTION = """
Distills an NLI-based zero-shot classifier to a smaller, more efficient model with a fixed set of candidate class
names. Useful for speeding up zero-shot classification in cases where labeled training data is not available, but
when only a single fixed set of classes is needed.
Takes a teacher NLI model, student classifier model, unlabeled
dataset, and set of K possible class names. Yields a single classifier with K outputs corresponding to the provided
class names.
"""

logger = logging.getLogger(__name__)


@dataclass
class TeacherModelArguments:
    teacher_name_or_path: Optional[str] = field(
        default="roberta-large-mnli", metadata={"help": "The NLI/zero-shot teacher model to be distilled."}
    )
    hypothesis_template: Optional[str] = field(
        default="This example is {}.",
        metadata={
            "help": (
                "Template used to turn class names into mock hypotheses for teacher NLI model. Must include {{}} "
                "where class name is inserted."
            )
        },
    )
    teacher_batch_size: Optional[int] = field(
        default=32, metadata={"help": "Batch size for generating teacher predictions."}
    )
    multi_class: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Allow multiple classes to be true rather than forcing them to sum to 1 (sometimes called "
                "multi-class multi-label classification)."
            )
        },
    )
    temperature: Optional[float] = field(
        default=1.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
    )


@dataclass
class StudentModelArguments:
    student_name_or_path: Optional[str] = field(
        default="distilbert-base-uncased",
        metadata={"help": "The student model to be fine-tuned to mimic the teacher predictions."},
    )


@dataclass
class DataTrainingArguments:
    data_file: str = field(metadata={"help": "Text file with one unlabeled instance per line."})
    class_names_file: str = field(metadata={"help": "Text file with one class name per line."})
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the Rust tokenizers library) or not."},
    )


@dataclass
class DistillTrainingArguments(TrainingArguments):
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    per_device_train_batch_size: int = field(
        default=32, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=128, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    num_train_epochs: float = field(default=1.0, metadata={"help": "Total number of training epochs to perform."})
    do_train: bool = field(default=True, metadata={"help": "Whether to run training of student model."})
    do_eval: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to evaluate the agreement of the final student predictions and the teacher predictions "
                "after training."
            )
        },
    )
    save_total_limit: Optional[int] = field(
        default=0,
        metadata={
            "help": (
                "Limit the total number of checkpoints. "
                "Deletes the older checkpoints in the output_dir. Default is 0 (no checkpoints)."
            )
        },
    )


class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        target_p = inputs["labels"]
        outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs[0]

        # Soft cross-entropy between the teacher distribution and the student log-probabilities.
        loss = -torch.sum(target_p * logits.log_softmax(dim=-1), dim=-1).mean()

        if return_outputs:
            return loss, outputs

        return loss


def read_lines(path):
    lines = []
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if len(line) > 0:
                lines.append(line)
    return lines


def get_premise_hypothesis_pairs(examples, class_names, hypothesis_template):
    premises = []
    hypotheses = []
    for example in examples:
        for name in class_names:
            premises.append(example)
            hypotheses.append(hypothesis_template.format(name))
    return premises, hypotheses


def get_entailment_id(config):
    for label, ind in config.label2id.items():
        if label.lower().startswith("entail"):
            return ind
    logger.warning("Could not identify entailment dimension from teacher config label2id. Setting to -1.")
    return -1


def get_teacher_predictions(
    model_path: str,
    examples: List[str],
    class_names: List[str],
    hypothesis_template: str,
    batch_size: int,
    temperature: float,
    multi_class: bool,
    use_fast_tokenizer: bool,
    no_cuda: bool,
    fp16: bool,
):
    """
    Gets predictions by the same method as the zero-shot pipeline but with DataParallel & more efficient batching
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model_config = model.config
    if not no_cuda and torch.cuda.is_available():
        model = nn.DataParallel(model)
        batch_size *= len(model.device_ids)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)

    premises, hypotheses = get_premise_hypothesis_pairs(examples, class_names, hypothesis_template)
    logits = []

    for i in tqdm(range(0, len(premises), batch_size)):
        batch_premises = premises[i : i + batch_size]
        batch_hypotheses = hypotheses[i : i + batch_size]

        encodings = tokenizer(
            batch_premises,
            batch_hypotheses,
            padding=True,
            truncation="only_first",
            return_tensors="pt",
        )

        with torch.cuda.amp.autocast(enabled=fp16):
            with torch.no_grad():
                outputs = model(**encodings)
        logits.append(outputs.logits.detach().cpu().float())

    entail_id = get_entailment_id(model_config)
    contr_id = -1 if entail_id == 0 else 0
    logits = torch.cat(logits, dim=0)  # N*K x 3
    nli_logits = logits.reshape(len(examples), len(class_names), -1)[..., [contr_id, entail_id]]  # N x K x 2

    if multi_class:
        # softmax over (contr, entail) logits for each class independently
        nli_prob = (nli_logits / temperature).softmax(-1)
    else:
        # softmax over entail logits across classes s.t. class probabilities sum to 1.
        nli_prob = (nli_logits / temperature).softmax(1)

    return nli_prob[..., 1]  # N x K


def main():
    parser = HfArgumentParser(
        (DataTrainingArguments, TeacherModelArguments, StudentModelArguments, DistillTrainingArguments),
        description=DESCRIPTION,
    )

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args, teacher_args, student_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        data_args, teacher_args, student_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        utils.logging.set_verbosity_info()
        utils.logging.enable_default_handler()
        utils.logging.enable_explicit_format()

    if training_args.local_rank != -1:
        raise ValueError("Distributed training is not currently supported.")
    if training_args.tpu_num_cores is not None:
        raise ValueError("TPU acceleration is not currently supported.")

    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 1. read in data
    examples = read_lines(data_args.data_file)
    class_names = read_lines(data_args.class_names_file)

    # 2. get teacher predictions and load into dataset
    logger.info("Generating predictions from zero-shot teacher model")
    teacher_soft_preds = get_teacher_predictions(
        teacher_args.teacher_name_or_path,
        examples,
        class_names,
        teacher_args.hypothesis_template,
        teacher_args.teacher_batch_size,
        teacher_args.temperature,
        teacher_args.multi_class,
        data_args.use_fast_tokenizer,
        training_args.no_cuda,
        training_args.fp16,
    )
    dataset = Dataset.from_dict(
        {
            "text": examples,
            "labels": teacher_soft_preds,
        }
    )

    # 3. create student
    logger.info("Initializing student model")
    model = AutoModelForSequenceClassification.from_pretrained(
        student_args.student_name_or_path, num_labels=len(class_names)
    )
    tokenizer = AutoTokenizer.from_pretrained(student_args.student_name_or_path, use_fast=data_args.use_fast_tokenizer)
    model.config.id2label = {i: label for i, label in enumerate(class_names)}
    model.config.label2id = {label: i for i, label in enumerate(class_names)}
    # 4. train student on teacher predictions
    dataset = dataset.map(tokenizer, input_columns="text")
    dataset.set_format("torch")

    def compute_metrics(p):
        preds = p.predictions.argmax(-1)
        proxy_labels = p.label_ids.argmax(-1)  # "label_ids" are actually distributions
        return {"agreement": (preds == proxy_labels).mean().item()}

    trainer = DistillationTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        compute_metrics=compute_metrics,
    )

    if training_args.do_train:
        logger.info("Training student model on teacher predictions")
        trainer.train()

    if training_args.do_eval:
        agreement = trainer.evaluate(eval_dataset=dataset)["eval_agreement"]
        logger.info(f"Agreement of student and teacher predictions: {agreement * 100:0.2f}%")

    trainer.save_model()


if __name__ == "__main__":
    main()