diff --git a/examples/bertology.py b/examples/bertology.py index bf32e8e1747..4bb23b8f168 100644 --- a/examples/bertology.py +++ b/examples/bertology.py @@ -25,17 +25,20 @@ def entropy(p): plogp[p == 0] = 0 return -plogp.sum(dim=-1) + def print_1d_tensor(tensor, prefix=""): if tensor.dtype != torch.long: logger.info(prefix + "\t".join(f"{x:.5f}" for x in tensor.cpu().data)) else: logger.info(prefix + "\t".join(f"{x:d}" for x in tensor.cpu().data)) + def print_2d_tensor(tensor): logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor)))) for row in range(len(tensor)): print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t") + def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None): """ Example on how to use model outputs to compute: - head attention entropy (activated by setting output_attentions=True when we created the model @@ -54,7 +57,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, batch = tuple(t.to(args.device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch - # Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below) + # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below) all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask) if compute_entropy: @@ -103,6 +106,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, return attn_entropy, head_importance, preds, labels + def run_model(): parser = argparse.ArgumentParser() parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint') @@ -212,7 +216,7 @@ def run_model(): eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) if args.data_subset > 0: - eval_data = Subset(eval_data, list(range(args.data_subset))) + eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data))))) eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) @@ -246,14 +250,14 @@ def run_model(): logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold) new_head_mask = torch.ones_like(head_importance) - num_to_mask = int(new_head_mask.numel() * args.masking_amount) + num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount)) current_score = original_score while current_score >= original_score * args.masking_threshold: - head_mask = new_head_mask # save current head mask - # heads from most important to least - keep only not-masked heads - head_importance = head_importance.view(-1)[head_mask.view(-1).nonzero()][:, 0] - current_heads_to_mask = head_importance.sort()[1] + head_mask = new_head_mask.clone() # save current head mask + # heads from least important to most - keep only not-masked heads + head_importance[head_mask == 0.0] = float('Inf') + current_heads_to_mask = head_importance.view(-1).sort()[1] if len(current_heads_to_mask) <= num_to_mask: break @@ -261,7 +265,7 @@ def run_model(): # mask heads current_heads_to_mask = current_heads_to_mask[:num_to_mask] logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist())) - new_head_mask = head_mask.view(-1) + new_head_mask = new_head_mask.view(-1) new_head_mask[current_heads_to_mask] = 0.0 new_head_mask = new_head_mask.view_as(head_mask) print_2d_tensor(new_head_mask) @@ -272,6 +276,10 @@ def run_model(): current_score = compute_metrics(task_name, preds, labels)[args.metric_name] logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100) + logger.info("Final head mask") + print_2d_tensor(head_mask) + np.save(os.path.join(args.output_dir, 'head_mask.npy'), head_mask.detach().cpu().numpy()) + # Try pruning and test time speedup # Pruning is like masking but we actually remove the masked weights before_time = datetime.now()