update head pruning

This commit is contained in:
thomwolf 2019-06-19 22:16:30 +02:00
parent 0f40e8d6a6
commit e4b46d86ce

View File

@ -92,7 +92,13 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
# Normalize
attn_entropy /= tot_tokens
head_importance /= tot_tokens
if args.normalize_importance:
# Layerwise importance normalization
if not args.dont_normalize_importance_by_layer:
exponent = 2
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1/exponent)
head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
if not args.dont_normalize_global_importance:
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
return attn_entropy, head_importance, preds, labels
@ -106,7 +112,8 @@ def run_model():
parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.")
parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory")
parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1")
parser.add_argument("--dont_normalize_importance_by_layer", action='store_true', help="Don't normalize importance score by layers")
parser.add_argument("--dont_normalize_global_importance", action='store_true', help="Don't normalize all importance scores between 0 and 1")
parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.")
parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics"
@ -243,21 +250,20 @@ def run_model():
current_score = original_score
while current_score >= original_score * args.masking_threshold:
head_mask = new_head_mask
# heads from most important to least
heads_to_mask = head_importance.view(-1).sort(descending=True)[1]
# keep only not-masked heads
heads_to_mask = heads_to_mask[head_mask.view(-1).nonzero()][:, 0]
head_mask = new_head_mask # save current head mask
# heads from most important to least - keep only not-masked heads
head_importance = head_importance.view(-1)[head_mask.view(-1).nonzero()][:, 0]
current_heads_to_mask = head_importance.sort()[1]
if len(heads_to_mask) <= num_to_mask:
if len(current_heads_to_mask) <= num_to_mask:
break
# mask heads
heads_to_mask = heads_to_mask[-num_to_mask:]
logger.info("Heads to mask: %s", str(heads_to_mask.tolist()))
current_heads_to_mask = current_heads_to_mask[:num_to_mask]
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
new_head_mask = head_mask.view(-1)
new_head_mask[heads_to_mask] = 0.0
new_head_mask = new_head_mask.view_as(head_importance)
new_head_mask[current_heads_to_mask] = 0.0
new_head_mask = new_head_mask.view_as(head_mask)
print_2d_tensor(new_head_mask)
# Compute metric and head importance again