mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
[examples] fix no grad in second pruning in run_bertology (#4479)
* fix no grad in second pruning and typo
* fix prune heads attention mismatch problem
* fix
* fix
* fix
* run make style
* run make style
This commit is contained in:
parent
865d4d595e
commit
271bedb485
@ -64,7 +64,7 @@ def print_2d_tensor(tensor):
|
|||||||
|
|
||||||
|
|
||||||
def compute_heads_importance(
|
def compute_heads_importance(
|
||||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None
|
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
|
||||||
):
|
):
|
||||||
""" This method shows how to compute:
|
""" This method shows how to compute:
|
||||||
- head attention entropy
|
- head attention entropy
|
||||||
@ -77,7 +77,12 @@ def compute_heads_importance(
|
|||||||
|
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
head_mask = torch.ones(n_layers, n_heads).to(args.device)
|
head_mask = torch.ones(n_layers, n_heads).to(args.device)
|
||||||
|
|
||||||
head_mask.requires_grad_(requires_grad=True)
|
head_mask.requires_grad_(requires_grad=True)
|
||||||
|
# If the attention heads have actually been pruned, set head_mask to None to avoid a shape mismatch
|
||||||
|
if actually_pruned:
|
||||||
|
head_mask = None
|
||||||
|
|
||||||
preds = None
|
preds = None
|
||||||
labels = None
|
labels = None
|
||||||
tot_tokens = 0.0
|
tot_tokens = 0.0
|
||||||
@ -172,6 +177,7 @@ def mask_heads(args, model, eval_dataloader):
|
|||||||
new_head_mask = new_head_mask.view(-1)
|
new_head_mask = new_head_mask.view(-1)
|
||||||
new_head_mask[current_heads_to_mask] = 0.0
|
new_head_mask[current_heads_to_mask] = 0.0
|
||||||
new_head_mask = new_head_mask.view_as(head_mask)
|
new_head_mask = new_head_mask.view_as(head_mask)
|
||||||
|
new_head_mask = new_head_mask.clone().detach()
|
||||||
print_2d_tensor(new_head_mask)
|
print_2d_tensor(new_head_mask)
|
||||||
|
|
||||||
# Compute metric and head importance again
|
# Compute metric and head importance again
|
||||||
@ -181,7 +187,7 @@ def mask_heads(args, model, eval_dataloader):
|
|||||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||||
logger.info(
|
logger.info(
|
||||||
"Masking: current score: %f, remaning heads %d (%.1f percents)",
|
"Masking: current score: %f, remaining heads %d (%.1f percents)",
|
||||||
current_score,
|
current_score,
|
||||||
new_head_mask.sum(),
|
new_head_mask.sum(),
|
||||||
new_head_mask.sum() / new_head_mask.numel() * 100,
|
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||||
@ -209,14 +215,23 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
|||||||
original_time = datetime.now() - before_time
|
original_time = datetime.now() - before_time
|
||||||
|
|
||||||
original_num_params = sum(p.numel() for p in model.parameters())
|
original_num_params = sum(p.numel() for p in model.parameters())
|
||||||
heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
|
heads_to_prune = dict(
|
||||||
|
(layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
|
||||||
|
)
|
||||||
|
|
||||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||||
model.prune_heads(heads_to_prune)
|
model.prune_heads(heads_to_prune)
|
||||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||||
|
|
||||||
before_time = datetime.now()
|
before_time = datetime.now()
|
||||||
_, _, preds, labels = compute_heads_importance(
|
_, _, preds, labels = compute_heads_importance(
|
||||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
|
args,
|
||||||
|
model,
|
||||||
|
eval_dataloader,
|
||||||
|
compute_entropy=False,
|
||||||
|
compute_importance=False,
|
||||||
|
head_mask=None,
|
||||||
|
actually_pruned=True,
|
||||||
)
|
)
|
||||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||||
|
Loading…
Reference in New Issue
Block a user