mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[examples] fix no grad in second pruning in run_bertology (#4479)
* fix no grad in second pruning and typo * fix prune heads attention mismatch problem * fix * fix * fix * run make style * run make style
This commit is contained in:
parent
865d4d595e
commit
271bedb485
@ -64,7 +64,7 @@ def print_2d_tensor(tensor):
|
||||
|
||||
|
||||
def compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
|
||||
):
|
||||
""" This method shows how to compute:
|
||||
- head attention entropy
|
||||
@ -77,7 +77,12 @@ def compute_heads_importance(
|
||||
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(n_layers, n_heads).to(args.device)
|
||||
|
||||
head_mask.requires_grad_(requires_grad=True)
|
||||
# If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
|
||||
if actually_pruned:
|
||||
head_mask = None
|
||||
|
||||
preds = None
|
||||
labels = None
|
||||
tot_tokens = 0.0
|
||||
@ -172,6 +177,7 @@ def mask_heads(args, model, eval_dataloader):
|
||||
new_head_mask = new_head_mask.view(-1)
|
||||
new_head_mask[current_heads_to_mask] = 0.0
|
||||
new_head_mask = new_head_mask.view_as(head_mask)
|
||||
new_head_mask = new_head_mask.clone().detach()
|
||||
print_2d_tensor(new_head_mask)
|
||||
|
||||
# Compute metric and head importance again
|
||||
@ -181,7 +187,7 @@ def mask_heads(args, model, eval_dataloader):
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
logger.info(
|
||||
"Masking: current score: %f, remaning heads %d (%.1f percents)",
|
||||
"Masking: current score: %f, remaining heads %d (%.1f percents)",
|
||||
current_score,
|
||||
new_head_mask.sum(),
|
||||
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||
@ -209,14 +215,23 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
original_time = datetime.now() - before_time
|
||||
|
||||
original_num_params = sum(p.numel() for p in model.parameters())
|
||||
heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
|
||||
heads_to_prune = dict(
|
||||
(layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
|
||||
)
|
||||
|
||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||
model.prune_heads(heads_to_prune)
|
||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
before_time = datetime.now()
|
||||
_, _, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
|
||||
args,
|
||||
model,
|
||||
eval_dataloader,
|
||||
compute_entropy=False,
|
||||
compute_importance=False,
|
||||
head_mask=None,
|
||||
actually_pruned=True,
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
|
Loading…
Reference in New Issue
Block a user