[examples] fix no grad in second pruning in run_bertology (#4479)

* fix no grad in second pruning and typo

* fix prune heads attention mismatch problem

* fix

* fix

* fix

* run make style

* run make style
Tobias Lee 2020-05-21 21:17:03 +08:00 committed by GitHub
parent 865d4d595e
commit 271bedb485


@@ -64,7 +64,7 @@ def print_2d_tensor(tensor):
 def compute_heads_importance(
-    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None
+    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
 ):
     """ This method shows how to compute:
         - head attention entropy
@@ -77,7 +77,12 @@ def compute_heads_importance(
     if head_mask is None:
         head_mask = torch.ones(n_layers, n_heads).to(args.device)
     head_mask.requires_grad_(requires_grad=True)
+    # If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
+    if actually_pruned:
+        head_mask = None
     preds = None
     labels = None
     tot_tokens = 0.0
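
Why the new actually_pruned flag: model.prune_heads physically removes head parameters, so a layer can end up with fewer heads than the full (n_layers, n_heads) mask assumes, and feeding such a mask into the forward pass would hit a shape mismatch. A minimal sketch of that situation, using an illustrative tiny BertConfig (the sizes are not from this commit):

import torch
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
model = BertModel(config)
model.prune_heads({0: [0, 1]})  # layer 0 now has 2 heads, layer 1 still has 4

# A full-size mask built for the unpruned model no longer matches layer 0,
# whose attention now only has 2 head slots:
full_mask = torch.ones(2, 4)  # (n_layers, n_heads) of the unpruned model
print(model.encoder.layer[0].attention.self.num_attention_heads)  # 2

This is why compute_heads_importance resets head_mask to None once the heads have actually been pruned.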
@@ -172,6 +177,7 @@ def mask_heads(args, model, eval_dataloader):
         new_head_mask = new_head_mask.view(-1)
         new_head_mask[current_heads_to_mask] = 0.0
         new_head_mask = new_head_mask.view_as(head_mask)
+        new_head_mask = new_head_mask.clone().detach()
         print_2d_tensor(new_head_mask)

         # Compute metric and head importance again
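
Why the clone().detach(): compute_heads_importance turns requires_grad on for the mask (so head importances can be read off head_mask.grad), and the next masking round then writes zeros into a view of that same tensor in place, which PyTorch rejects for leaf variables that require grad. A minimal standalone sketch of the pitfall, with illustrative shapes:

import torch

mask = torch.ones(2, 4)
mask.requires_grad_(requires_grad=True)  # what compute_heads_importance does to the mask

flat = mask.view(-1)
try:
    flat[0] = 0.0  # in-place write into a view of a leaf that requires grad
except RuntimeError as err:
    print(err)  # "... a view of a leaf Variable that requires grad is being used in an in-place operation"

fresh = mask.clone().detach()  # fresh leaf tensor with no autograd history
flat_fresh = fresh.view(-1)
flat_fresh[0] = 0.0  # safe: fresh does not require grad

Detaching at the end of each round hands the next compute_heads_importance call a fresh leaf it can safely re-enable gradients on, which is the "no grad in second pruning" fix from the title.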
@@ -181,7 +187,7 @@ def mask_heads(args, model, eval_dataloader):
         preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
         current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
         logger.info(
-            "Masking: current score: %f, remaning heads %d (%.1f percents)",
+            "Masking: current score: %f, remaining heads %d (%.1f percents)",
             current_score,
             new_head_mask.sum(),
             new_head_mask.sum() / new_head_mask.numel() * 100,
@@ -209,14 +215,23 @@ def prune_heads(args, model, eval_dataloader, head_mask):
     original_time = datetime.now() - before_time

     original_num_params = sum(p.numel() for p in model.parameters())
-    heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
+    heads_to_prune = dict(
+        (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
+    )
     assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
     model.prune_heads(heads_to_prune)
     pruned_num_params = sum(p.numel() for p in model.parameters())

     before_time = datetime.now()
     _, _, preds, labels = compute_heads_importance(
-        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
+        args,
+        model,
+        eval_dataloader,
+        compute_entropy=False,
+        compute_importance=False,
+        head_mask=None,
+        actually_pruned=True,
     )
     preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
     score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
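
On the heads_to_prune change: nonzero() returns its matches as an (N, 1) column, so .tolist() alone yields nested single-element lists rather than the flat per-layer head lists that model.prune_heads expects; squeeze() flattens the column first. A small standalone illustration (the mask values are made up):

import torch

head_mask_row = torch.tensor([1.0, 0.0, 1.0, 0.0])  # one layer; heads 1 and 3 masked
pruned = (1 - head_mask_row.long()).nonzero()
print(pruned.tolist())            # [[1], [3]] -- nested lists, wrong shape for prune_heads
print(pruned.squeeze().tolist())  # [1, 3]     -- flat list of head indices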