Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-03 03:31:05 +06:00
update head pruning
This commit is contained in:
parent
0f40e8d6a6
commit
e4b46d86ce
@@ -92,7 +92,13 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
     # Normalize
     attn_entropy /= tot_tokens
     head_importance /= tot_tokens
-    if args.normalize_importance:
+    # Layerwise importance normalization
+    if not args.dont_normalize_importance_by_layer:
+        exponent = 2
+        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1/exponent)
+        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
+
+    if not args.dont_normalize_global_importance:
         head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
 
     return attn_entropy, head_importance, preds, labels
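For context, the importance scores are now normalized in two optional stages instead of a single min-max pass. Below is a minimal standalone sketch of the same math on a made-up (n_layers, n_heads) tensor; the tensor values, shapes, and the final print are illustrative only, not taken from the script.

import torch

# hypothetical per-head importance scores, shape (n_layers=2, n_heads=3)
head_importance = torch.tensor([[4.0, 1.0, 3.0],
                                [2.0, 8.0, 6.0]])

# layerwise normalization: divide each layer's row by its L2 norm
exponent = 2
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
head_importance = head_importance / (norm_by_layer.unsqueeze(-1) + 1e-20)

# global normalization: rescale every score into [0, 1]
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

print(head_importance)  # each row is unit-normed first, then all scores are min-max scaled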
@@ -106,7 +112,8 @@ def run_model():
     parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.")
     parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory")
 
-    parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1")
+    parser.add_argument("--dont_normalize_importance_by_layer", action='store_true', help="Don't normalize importance score by layers")
+    parser.add_argument("--dont_normalize_global_importance", action='store_true', help="Don't normalize all importance scores between 0 and 1")
 
     parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.")
     parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics"
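The opt-in --normalize_importance flag is replaced by two opt-out flags, so both normalization stages run by default. A small sketch of the resulting semantics; the bare parser and the empty argument list are illustrative, not part of the script.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dont_normalize_importance_by_layer", action='store_true',
                    help="Don't normalize importance score by layers")
parser.add_argument("--dont_normalize_global_importance", action='store_true',
                    help="Don't normalize all importance scores between 0 and 1")

args = parser.parse_args([])  # no flags passed on the command line
assert not args.dont_normalize_importance_by_layer   # layerwise normalization stays enabled
assert not args.dont_normalize_global_importance     # global normalization stays enabled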
@@ -243,21 +250,20 @@ def run_model():
 
         current_score = original_score
         while current_score >= original_score * args.masking_threshold:
-            head_mask = new_head_mask
-            # heads from most important to least
-            heads_to_mask = head_importance.view(-1).sort(descending=True)[1]
-            # keep only not-masked heads
-            heads_to_mask = heads_to_mask[head_mask.view(-1).nonzero()][:, 0]
+            head_mask = new_head_mask  # save current head mask
+            # heads from most important to least - keep only not-masked heads
+            head_importance = head_importance.view(-1)[head_mask.view(-1).nonzero()][:, 0]
+            current_heads_to_mask = head_importance.sort()[1]
 
-            if len(heads_to_mask) <= num_to_mask:
+            if len(current_heads_to_mask) <= num_to_mask:
                 break
 
             # mask heads
-            heads_to_mask = heads_to_mask[-num_to_mask:]
-            logger.info("Heads to mask: %s", str(heads_to_mask.tolist()))
+            current_heads_to_mask = current_heads_to_mask[:num_to_mask]
+            logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
             new_head_mask = head_mask.view(-1)
-            new_head_mask[heads_to_mask] = 0.0
-            new_head_mask = new_head_mask.view_as(head_importance)
+            new_head_mask[current_heads_to_mask] = 0.0
+            new_head_mask = new_head_mask.view_as(head_mask)
             print_2d_tensor(new_head_mask)
 
             # Compute metric and head importance again
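The loop now derives the candidate heads from the current mask on every iteration: it keeps only still-unmasked heads, sorts them by importance, and zeroes out num_to_mask of the least important ones until the score drops below the threshold. Here is a toy, self-contained sketch of that control flow; the evaluate() stand-in, the tensor values, and the thresholds are made up, and the sorted positions are mapped back to flat head indices explicitly, which the committed lines do not spell out.

import torch

def evaluate(head_mask):
    # hypothetical stand-in for re-running evaluation with a given head mask
    return float(head_mask.sum()) / head_mask.numel()

head_importance = torch.tensor([[4.0, 1.0, 3.0], [2.0, 8.0, 6.0]])
masking_threshold, masking_amount = 0.9, 0.5
new_head_mask = torch.ones_like(head_importance)
num_to_mask = max(1, int(new_head_mask.numel() * masking_amount))

original_score = evaluate(new_head_mask)
current_score = original_score
while current_score >= original_score * masking_threshold:
    head_mask = new_head_mask  # save current head mask
    # flat indices of still-unmasked heads, ordered least important first
    alive = head_mask.view(-1).nonzero()[:, 0]
    order = head_importance.view(-1)[alive].sort()[1]
    if len(alive) <= num_to_mask:
        break
    current_heads_to_mask = alive[order][:num_to_mask]
    new_head_mask = head_mask.view(-1).clone()
    new_head_mask[current_heads_to_mask] = 0.0  # mask the least important heads
    new_head_mask = new_head_mask.view_as(head_mask)
    current_score = evaluate(new_head_mask)  # re-check the metric with the new mask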