Line endings should be LF across repo and not CRLF (#10119)

Lysandre Debut 2021-02-10 16:50:00 +01:00 committed by GitHub
parent 937f67074d
commit 0d8e554d42
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
3 changed files with 681 additions and 678 deletions

.gitattributes vendored Normal file

@@ -0,0 +1,3 @@
*.py eol=lf
*.rst eol=lf
*.md eol=lf
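
Note: for files matching these patterns, eol=lf makes Git normalize line endings to LF on checkin and keep LF in the working tree on checkout, regardless of core.autocrlf. The attribute alone does not rewrite files already committed with CRLF; a typical follow-up on Git 2.16+ is to run "git add --renormalize ." and commit the result.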

examples/research_projects/bertology/run_prune_gpt.py

@@ -1,388 +1,388 @@
#!/usr/bin/env python3
""" This script is adapted from the Bertology pruning code (https://github.com/huggingface/transformers/blob/783d7d2629e97c5f0c5f9ef01b8c66410275c204/examples/research_projects/bertology/run_bertology.py)
to prune GPT-like models. The author is @altsoph.
"""

import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from tqdm import tqdm

from transformers import GPT2LMHeadModel


logger = logging.getLogger(__name__)


def save_model(model, dirpath):
    # save results
    if os.path.exists(dirpath):
        if os.path.exists(os.path.join(dirpath, "config.json")) and os.path.isfile(
            os.path.join(dirpath, "config.json")
        ):
            os.remove(os.path.join(dirpath, "config.json"))
        if os.path.exists(os.path.join(dirpath, "pytorch_model.bin")) and os.path.isfile(
            os.path.join(dirpath, "pytorch_model.bin")
        ):
            os.remove(os.path.join(dirpath, "pytorch_model.bin"))
    else:
        os.makedirs(dirpath)
    model.save_pretrained(dirpath)


def entropy(p, unlogit=False):
    """ Compute the entropy of a probability distribution """
    exponent = 2
    if unlogit:
        p = torch.pow(p, exponent)
    plogp = p * torch.log(p)
    plogp[p == 0] = 0
    return -plogp.sum(dim=-1)
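
# For reference: entropy() returns the Shannon entropy (in nats) along the last
# dimension, so a uniform distribution over four outcomes gives ln 4:
#     entropy(torch.tensor([0.25, 0.25, 0.25, 0.25]))  # tensor(1.3863)
# With unlogit=True the values are squared first, which is how the helper is
# applied to the attention weights in compute_heads_importance() below.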
def print_2d_tensor(tensor):
    """ Print a 2D tensor """
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
    for row in range(len(tensor)):
        if tensor.dtype != torch.long:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
        else:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))


def compute_heads_importance(
    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
):
    """This method shows how to compute:
    - head attention entropy
    - head importance scores according to http://arxiv.org/abs/1905.10650
    """
    # Prepare our tensors
    n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)

    if head_mask is None:
        head_mask = torch.ones(n_layers, n_heads).to(args.device)

    head_mask.requires_grad_(requires_grad=True)
    # If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
    if actually_pruned:
        head_mask = None

    tot_tokens = 0.0
    total_loss = 0.0
    for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
        inputs = tuple(t.to(args.device) for t in inputs)
        (input_ids,) = inputs

        # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
        outputs = model(input_ids, labels=input_ids, head_mask=head_mask)
        # (loss), lm_logits, presents, (all hidden_states), (attentions)
        loss, _, all_attentions = (
            outputs[0],
            outputs[1],
            outputs[-1],
        )  # Loss and logits are the first, attention the last
        loss.backward()  # Backpropagate to populate the gradients in the head mask
        total_loss += loss.detach().cpu().numpy()

        if compute_entropy:
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach(), True)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).sum(0).detach()

        if compute_importance:
            head_importance += head_mask.grad.abs().detach()
        tot_tokens += torch.ones_like(input_ids).float().detach().sum().data

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
        head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

    # Print matrices
    if compute_entropy:
        logger.info("Attention entropies")
        print_2d_tensor(attn_entropy)
    if compute_importance:
        logger.info("Head importance scores")
        print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
        head_importance.numel(), device=args.device
    )
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)
    return attn_entropy, head_importance, total_loss
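
# Per Michel et al. (2019, http://arxiv.org/abs/1905.10650), the importance of a
# head h is I_h = E_x |dL(x) / d xi_h|, where xi_h is the head's mask variable;
# compute_heads_importance() estimates this expectation by accumulating
# |head_mask.grad| over the evaluation set and normalizing by the token count.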
def mask_heads(args, model, eval_dataloader):
    """This method shows how to mask head (set some heads to zero), to test the effect on the network,
    based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    _, head_importance, loss = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    original_score = 1 / loss  # instead of downstream score use the LM loss
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    current_score = original_score
    while current_score >= original_score * args.masking_threshold:
        head_mask = new_head_mask.clone().detach()  # save current head mask
        # heads from least important to most - keep only not-masked heads
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]

        if len(current_heads_to_mask) <= num_to_mask:
            print("BREAK BY num_to_mask")
            break

        # mask heads
        current_heads_to_mask = current_heads_to_mask[:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
        new_head_mask = new_head_mask.view(-1)
        new_head_mask[current_heads_to_mask] = 0.0
        new_head_mask = new_head_mask.view_as(head_mask)
        new_head_mask = new_head_mask.clone().detach()
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, loss = compute_heads_importance(
            args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
        )
        current_score = 1 / loss
        logger.info(
            "Masking: current score: %f, remaining heads %d (%.1f percents)",
            current_score,
            new_head_mask.sum(),
            new_head_mask.sum() / new_head_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())

    return head_mask
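
# The masking loop above stops once 1 / loss falls below masking_threshold times
# the original 1 / loss, i.e. masking may grow the LM loss by at most a factor of
# 1 / masking_threshold (roughly 11% with the default masking_threshold of 0.9).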
def prune_heads(args, model, eval_dataloader, head_mask):
    """This method shows how to prune head (remove heads weights) based on
    the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    # Try pruning and test time speedup
    # Pruning is like masking but we actually remove the masked weights
    before_time = datetime.now()
    _, _, loss = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
    )
    score_masking = 1 / loss
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
    heads_to_prune = dict(
        (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
    )

    for k, v in heads_to_prune.items():
        if isinstance(v, int):
            heads_to_prune[k] = [
                v,
            ]

    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
    model.prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    before_time = datetime.now()
    _, _, loss = compute_heads_importance(
        args,
        model,
        eval_dataloader,
        compute_entropy=False,
        compute_importance=False,
        head_mask=None,
        actually_pruned=True,
    )
    score_pruning = 1 / loss
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time / new_time * 100)
    save_model(model, args.output_dir)
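
# Unlike masking, model.prune_heads() physically removes the pruned heads' rows
# and columns from the attention projections, so the parameter count and the
# timing measured above genuinely shrink instead of merely zeroing activations.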
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )
    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Amount of heads to mask at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, sequences shorter padded.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))

    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Print/save training arguments
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset
    numpy_data = np.concatenate(
        [
            np.loadtxt(args.data_dir, dtype=np.int64),
        ]
    )
    train_tensor_dataset = (torch.from_numpy(numpy_data),)
    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    eval_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
        head_mask = mask_heads(args, model, eval_dataloader)
        prune_heads(args, model, eval_dataloader, head_mask)


if __name__ == "__main__":
    main()
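
Note that the script reads its evaluation data with np.loadtxt(args.data_dir), so --data_dir should point at a whitespace-separated text file of token ids rather than a directory. A minimal invocation might look like this (file names are illustrative):

python run_prune_gpt.py --model_name_or_path gpt2 --data_dir token_ids.txt --output_dir pruned_gpt2 --try_masking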

tests/test_modeling_deberta.py

@@ -1,290 +1,290 @@
# coding=utf-8
# Copyright 2018 Microsoft Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest

import numpy as np

from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor


if is_torch_available():
    import torch

    from transformers import (
        DebertaConfig,
        DebertaForMaskedLM,
        DebertaForQuestionAnswering,
        DebertaForSequenceClassification,
        DebertaForTokenClassification,
        DebertaModel,
    )
    from transformers.models.deberta.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST


@require_torch
class DebertaModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            DebertaModel,
            DebertaForMaskedLM,
            DebertaForSequenceClassification,
            DebertaForTokenClassification,
            DebertaForQuestionAnswering,
        )
        if is_torch_available()
        else ()
    )

    test_torchscript = False
    test_pruning = False
    test_head_masking = False
    is_encoder_decoder = False

    class DebertaModelTester(object):
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_mask=True,
            use_token_type_ids=True,
            use_labels=True,
            vocab_size=99,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            intermediate_size=37,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=16,
            type_sequence_label_size=2,
            initializer_range=0.02,
            relative_attention=False,
            position_biased_input=True,
            pos_att_type="None",
            num_labels=3,
            num_choices=4,
            scope=None,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.relative_attention = relative_attention
            self.position_biased_input = position_biased_input
            self.pos_att_type = pos_att_type
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = DebertaConfig(
                vocab_size=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range,
                relative_attention=self.relative_attention,
                position_biased_input=self.position_biased_input,
                pos_att_type=self.pos_att_type,
            )

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def check_loss_output(self, result):
            self.parent.assertListEqual(list(result.loss.size()), [])

        def create_and_check_deberta_model(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = DebertaModel(config=config)
            model.to(torch_device)
            model.eval()
            sequence_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
            sequence_output = model(input_ids, token_type_ids=token_type_ids)[0]
            sequence_output = model(input_ids)[0]

            self.parent.assertListEqual(
                list(sequence_output.size()), [self.batch_size, self.seq_length, self.hidden_size]
            )

        def create_and_check_deberta_for_masked_lm(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = DebertaForMaskedLM(config=config)
            model.to(torch_device)
            model.eval()
            result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

        def create_and_check_deberta_for_sequence_classification(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_labels = self.num_labels
            model = DebertaForSequenceClassification(config)
            model.to(torch_device)
            model.eval()
            result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
            self.parent.assertListEqual(list(result.logits.size()), [self.batch_size, self.num_labels])
            self.check_loss_output(result)

        def create_and_check_deberta_for_token_classification(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_labels = self.num_labels
            model = DebertaForTokenClassification(config=config)
            model.to(torch_device)
            model.eval()
            result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

        def create_and_check_deberta_for_question_answering(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = DebertaForQuestionAnswering(config=config)
            model.to(torch_device)
            model.eval()
            result = model(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=token_type_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
            )
            self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
            self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (
                config,
                input_ids,
                token_type_ids,
                input_mask,
                sequence_labels,
                token_labels,
                choice_labels,
            ) = config_and_inputs
            inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = DebertaModelTest.DebertaModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DebertaConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_deberta_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_model(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_sequence_classification(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_masked_lm(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_question_answering(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = DebertaModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


@require_torch
@require_sentencepiece
@require_tokenizers
class DebertaModelIntegrationTest(unittest.TestCase):
    @unittest.skip(reason="Model not available yet")
    def test_inference_masked_lm(self):
        pass

    @slow
    def test_inference_no_head(self):
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        model = DebertaModel.from_pretrained("microsoft/deberta-base")

        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        # compare the actual values for a slice.
        expected_slice = torch.tensor(
            [[[-0.0218, -0.6641, -0.3665], [-0.3907, -0.4716, -0.6640], [0.7461, 1.2570, -0.9063]]]
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4), f"{output[:, :3, :3]}")
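
For context, the shape assertions in DebertaModelTester mirror plain usage of the model. A minimal standalone sketch with the tester's default sizes (config values taken from the class above, weights randomly initialized):

import torch
from transformers import DebertaConfig, DebertaModel

config = DebertaConfig(
    vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = DebertaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (13, 7))  # (batch_size, seq_length)
with torch.no_grad():
    sequence_output = model(input_ids)[0]  # last hidden states
print(sequence_output.shape)  # torch.Size([13, 7, 32])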