mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add DeeBERT (entropy-based early exiting for *BERT) (#5477)
* Add deebert code * Add readme of deebert * Add test for deebert Update test for Deebert * Update DeeBert (README, class names, function refactoring); remove requirements.txt * Format update * Update test * Update readme and model init methods
This commit is contained in:
parent
b4b33fdf25
commit
cfbb982974
54
examples/deebert/README.md
Normal file
54
examples/deebert/README.md
Normal file
@ -0,0 +1,54 @@
|
||||
# DeeBERT: Early Exiting for *BERT
|
||||
|
||||
This is the code base for the paper [DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference](https://www.aclweb.org/anthology/2020.acl-main.204/), modified from its [original code base](https://github.com/castorini/deebert).
|
||||
|
||||
The original code base also has information for downloading sample models that we have trained in advance.
|
||||
|
||||
## Usage
|
||||
|
||||
There are three scripts in the folder which can be run directly.
|
||||
|
||||
In each script, there are several things to modify before running:
|
||||
|
||||
* `PATH_TO_DATA`: path to the GLUE dataset.
|
||||
* `--output_dir`: path for saving fine-tuned models. Default: `./saved_models`.
|
||||
* `--plot_data_dir`: path for saving evaluation results. Default: `./results`. Results are printed to stdout and also saved to `npy` files in this directory to facilitate plotting figures and further analyses.
|
||||
* `MODEL_TYPE`: bert or roberta
|
||||
* `MODEL_SIZE`: base or large
|
||||
* `DATASET`: SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
#### train_deebert.sh
|
||||
|
||||
This is for fine-tuning DeeBERT models.
|
||||
|
||||
#### eval_deebert.sh
|
||||
|
||||
This is for evaluating each exit layer for fine-tuned DeeBERT models.
|
||||
|
||||
#### entropy_eval.sh
|
||||
|
||||
This is for evaluating fine-tuned DeeBERT models, given a number of different early exit entropy thresholds.
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
Please cite our paper if you find the resource useful:
|
||||
```
|
||||
@inproceedings{xin-etal-2020-deebert,
|
||||
title = "{D}ee{BERT}: Dynamic Early Exiting for Accelerating {BERT} Inference",
|
||||
author = "Xin, Ji and
|
||||
Tang, Raphael and
|
||||
Lee, Jaejun and
|
||||
Yu, Yaoliang and
|
||||
Lin, Jimmy",
|
||||
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
|
||||
month = jul,
|
||||
year = "2020",
|
||||
address = "Online",
|
||||
publisher = "Association for Computational Linguistics",
|
||||
url = "https://www.aclweb.org/anthology/2020.acl-main.204",
|
||||
pages = "2246--2251",
|
||||
}
|
||||
```
|
||||
|
33
examples/deebert/entropy_eval.sh
Executable file
33
examples/deebert/entropy_eval.sh
Executable file
@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
ENTROPIES="0 0.1 0.2 0.3 0.4 0.5 0.6 0.7"
|
||||
|
||||
for ENTROPY in $ENTROPIES; do
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--task_name $DATASET \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--max_seq_length 128 \
|
||||
--early_exit_entropy $ENTROPY \
|
||||
--eval_highway \
|
||||
--overwrite_cache \
|
||||
--per_gpu_eval_batch_size=1
|
||||
done
|
30
examples/deebert/eval_deebert.sh
Executable file
30
examples/deebert/eval_deebert.sh
Executable file
@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--task_name $DATASET \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--max_seq_length 128 \
|
||||
--eval_each_highway \
|
||||
--eval_highway \
|
||||
--overwrite_cache \
|
||||
--per_gpu_eval_batch_size=1
|
720
examples/deebert/run_glue_deebert.py
Normal file
720
examples/deebert/run_glue_deebert.py
Normal file
@ -0,0 +1,720 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from src.modeling_highway_bert import DeeBertForSequenceClassification
|
||||
from src.modeling_highway_roberta import DeeRobertaForSequenceClassification
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, DeeBertForSequenceClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, DeeRobertaForSequenceClassification, RobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def get_wanted_result(result):
|
||||
if "spearmanr" in result:
|
||||
print_result = result["spearmanr"]
|
||||
elif "f1" in result:
|
||||
print_result = result["f1"]
|
||||
elif "mcc" in result:
|
||||
print_result = result["mcc"]
|
||||
elif "acc" in result:
|
||||
print_result = result["acc"]
|
||||
else:
|
||||
raise ValueError("Primary metric unclear in the results")
|
||||
return print_result
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, train_highway=False):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
if train_highway:
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" in n) and (not any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters() if ("highway" in n) and (any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
else:
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" not in n) and (not any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if ("highway" not in n) and (any(nd in n for nd in no_decay))
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
inputs["train_highway"] = train_highway
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=False):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
exit_layer_counter = {(i + 1): 0 for i in range(model.num_layers)}
|
||||
st = time.time()
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
if output_layer >= 0:
|
||||
inputs["output_layer"] = output_layer
|
||||
outputs = model(**inputs)
|
||||
if eval_highway:
|
||||
exit_layer_counter[outputs[-1]] += 1
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
eval_time = time.time() - st
|
||||
logger.info("Eval time: {}".format(eval_time))
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
if eval_highway:
|
||||
logger.info("Exit layer counter: {}".format(exit_layer_counter))
|
||||
actual_cost = sum([l * c for l, c in exit_layer_counter.items()])
|
||||
full_cost = len(eval_dataloader) * model.num_layers
|
||||
logger.info("Expected saving: {}".format(actual_cost / full_cost))
|
||||
if args.early_exit_entropy >= 0:
|
||||
save_fname = (
|
||||
args.plot_data_dir
|
||||
+ "/"
|
||||
+ args.model_name_or_path[2:]
|
||||
+ "/entropy_{}.npy".format(args.early_exit_entropy)
|
||||
)
|
||||
if not os.path.exists(os.path.dirname(save_fname)):
|
||||
os.makedirs(os.path.dirname(save_fname))
|
||||
print_result = get_wanted_result(result)
|
||||
np.save(save_fname, np.array([exit_layer_counter, eval_time, actual_cost / full_cost, print_result]))
|
||||
logger.info("Entropy={}\tResult={:.2f}".format(args.early_exit_entropy, 100 * print_result))
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples, tokenizer, label_list=label_list, max_length=args.max_seq_length, output_mode=output_mode,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
|
||||
if features[0].token_type_ids is None:
|
||||
# For RoBERTa (a potential bug!)
|
||||
all_token_type_ids = torch.tensor([[0] * args.max_seq_length for f in features], dtype=torch.long)
|
||||
else:
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot_data_dir",
|
||||
default="./plotting/",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The directory to store data for plotting figures.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
parser.add_argument("--eval_each_highway", action="store_true", help="Set this flag to evaluate each highway.")
|
||||
parser.add_argument(
|
||||
"--eval_after_first_stage",
|
||||
action="store_true",
|
||||
help="Set this flag to evaluate after training only bert (not highway).",
|
||||
)
|
||||
parser.add_argument("--eval_highway", action="store_true", help="Set this flag if it's evaluating highway models")
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--early_exit_entropy", default=-1, type=float, help="Entropy threshold for early exit.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.model_type == "bert":
|
||||
model.bert.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
model.bert.init_highway_pooler()
|
||||
elif args.model_type == "roberta":
|
||||
model.roberta.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
model.roberta.init_highway_pooler()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
if args.eval_after_first_stage:
|
||||
result = evaluate(args, model, tokenizer, prefix="")
|
||||
print_result = get_wanted_result(result)
|
||||
|
||||
train(args, train_dataset, model, tokenizer, train_highway=True)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
if args.model_type == "bert":
|
||||
model.bert.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
elif args.model_type == "roberta":
|
||||
model.roberta.encoder.set_early_exit_entropy(args.early_exit_entropy)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix, eval_highway=args.eval_highway)
|
||||
print_result = get_wanted_result(result)
|
||||
logger.info("Result: {}".format(print_result))
|
||||
if args.eval_each_highway:
|
||||
last_layer_results = print_result
|
||||
each_layer_results = []
|
||||
for i in range(model.num_layers):
|
||||
logger.info("\n")
|
||||
_result = evaluate(
|
||||
args, model, tokenizer, prefix=prefix, output_layer=i, eval_highway=args.eval_highway
|
||||
)
|
||||
if i + 1 < model.num_layers:
|
||||
each_layer_results.append(get_wanted_result(_result))
|
||||
each_layer_results.append(last_layer_results)
|
||||
save_fname = args.plot_data_dir + "/" + args.model_name_or_path[2:] + "/each_layer.npy"
|
||||
if not os.path.exists(os.path.dirname(save_fname)):
|
||||
os.makedirs(os.path.dirname(save_fname))
|
||||
np.save(save_fname, np.array(each_layer_results))
|
||||
info_str = "Score of each layer:"
|
||||
for i in range(model.num_layers):
|
||||
info_str += " {:.2f}".format(100 * each_layer_results[i])
|
||||
logger.info(info_str)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
0
examples/deebert/src/__init__.py
Normal file
0
examples/deebert/src/__init__.py
Normal file
396
examples/deebert/src/modeling_highway_bert.py
Normal file
396
examples/deebert/src/modeling_highway_bert.py
Normal file
@ -0,0 +1,396 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from transformers.modeling_bert import (
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
BERT_START_DOCSTRING,
|
||||
BertEmbeddings,
|
||||
BertLayer,
|
||||
BertPooler,
|
||||
BertPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
def entropy(x):
|
||||
""" Calculate entropy of a pre-softmax logit Tensor
|
||||
"""
|
||||
exp_x = torch.exp(x)
|
||||
A = torch.sum(exp_x, dim=1) # sum of exp(x_i)
|
||||
B = torch.sum(x * exp_x, dim=1) # sum of x_i * exp(x_i)
|
||||
return torch.log(A) - B / A
|
||||
|
||||
|
||||
class DeeBertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]
|
||||
|
||||
def set_early_exit_entropy(self, x):
|
||||
if (type(x) is float) or (type(x) is int):
|
||||
for i in range(len(self.early_exit_entropy)):
|
||||
self.early_exit_entropy[i] = x
|
||||
else:
|
||||
self.early_exit_entropy = x
|
||||
|
||||
def init_highway_pooler(self, pooler):
|
||||
loaded_model = pooler.state_dict()
|
||||
for highway in self.highway:
|
||||
for name, param in highway.pooler.state_dict().items():
|
||||
param.copy_(loaded_model[name])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
all_highway_exits = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
current_outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
current_outputs = current_outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
current_outputs = current_outputs + (all_attentions,)
|
||||
|
||||
highway_exit = self.highway[i](current_outputs)
|
||||
# logits, pooled_output
|
||||
|
||||
if not self.training:
|
||||
highway_logits = highway_exit[0]
|
||||
highway_entropy = entropy(highway_logits)
|
||||
highway_exit = highway_exit + (highway_entropy,) # logits, hidden_states(?), entropy
|
||||
all_highway_exits = all_highway_exits + (highway_exit,)
|
||||
|
||||
if highway_entropy < self.early_exit_entropy[i]:
|
||||
new_output = (highway_logits,) + current_outputs[1:] + (all_highway_exits,)
|
||||
raise HighwayException(new_output, i + 1)
|
||||
else:
|
||||
all_highway_exits = all_highway_exits + (highway_exit,)
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
|
||||
outputs = outputs + (all_highway_exits,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions), all highway exits
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The Bert Model transformer with early exiting (DeeBERT). ", BERT_START_DOCSTRING,
|
||||
)
|
||||
class DeeBertModel(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = DeeBertEncoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_highway_pooler(self):
|
||||
self.encoder.init_highway_pooler(self.pooler)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during pre-training.
|
||||
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
||||
|
||||
class HighwayException(Exception):
|
||||
def __init__(self, message, exit_layer):
|
||||
self.message = message
|
||||
self.exit_layer = exit_layer # start from 1!
|
||||
|
||||
|
||||
class BertHighway(nn.Module):
|
||||
"""A module to provide a shortcut
|
||||
from (the output of one non-final BertLayer in BertEncoder) to (cross-entropy computation in BertForSequenceClassification)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.pooler = BertPooler(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
def forward(self, encoder_outputs):
|
||||
# Pooler
|
||||
pooler_input = encoder_outputs[0]
|
||||
pooler_output = self.pooler(pooler_input)
|
||||
# "return" pooler_output
|
||||
|
||||
# BertModel
|
||||
bmodel_output = (pooler_input, pooler_output) + encoder_outputs[1:]
|
||||
# "return" bodel_output
|
||||
|
||||
# Dropout and classification
|
||||
pooled_output = bmodel_output[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
return logits, pooled_output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bert Model (with early exiting - DeeBERT) with a classifier on top,
|
||||
also takes care of multi-layer training. """,
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class DeeBertForSequenceClassification(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.bert = DeeBertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_layer=-1,
|
||||
train_highway=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
|
||||
exit_layer = self.num_layers
|
||||
try:
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
# sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
except HighwayException as e:
|
||||
outputs = e.message
|
||||
exit_layer = e.exit_layer
|
||||
logits = outputs[0]
|
||||
|
||||
if not self.training:
|
||||
original_entropy = entropy(logits)
|
||||
highway_entropy = []
|
||||
highway_logits_all = []
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
# work with highway exits
|
||||
highway_losses = []
|
||||
for highway_exit in outputs[-1]:
|
||||
highway_logits = highway_exit[0]
|
||||
if not self.training:
|
||||
highway_logits_all.append(highway_logits)
|
||||
highway_entropy.append(highway_exit[2])
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
highway_losses.append(highway_loss)
|
||||
|
||||
if train_highway:
|
||||
outputs = (sum(highway_losses[:-1]),) + outputs
|
||||
# exclude the final highway, of course
|
||||
else:
|
||||
outputs = (loss,) + outputs
|
||||
if not self.training:
|
||||
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
|
||||
if output_layer >= 0:
|
||||
outputs = (
|
||||
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
|
||||
) # use the highway of the last layer
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions), (highway_exits)
|
151
examples/deebert/src/modeling_highway_roberta.py
Normal file
151
examples/deebert/src/modeling_highway_roberta.py
Normal file
@ -0,0 +1,151 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.configuration_roberta import RobertaConfig
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from transformers.modeling_roberta import ROBERTA_INPUTS_DOCSTRING, ROBERTA_START_DOCSTRING, RobertaEmbeddings
|
||||
|
||||
from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayException, entropy
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The RoBERTa Model transformer with early exiting (DeeRoBERTa). ", ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaModel(DeeBertModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
self.init_weights()
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""RoBERTa Model (with early exiting - DeeRoBERTa) with a classifier on top,
|
||||
also takes care of multi-layer training. """,
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.roberta = DeeRobertaModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_layer=-1,
|
||||
train_highway=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
|
||||
Tuple of each early exit's results (total length: number of layers)
|
||||
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
|
||||
"""
|
||||
|
||||
exit_layer = self.num_layers
|
||||
try:
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
except HighwayException as e:
|
||||
outputs = e.message
|
||||
exit_layer = e.exit_layer
|
||||
logits = outputs[0]
|
||||
|
||||
if not self.training:
|
||||
original_entropy = entropy(logits)
|
||||
highway_entropy = []
|
||||
highway_logits_all = []
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
# work with highway exits
|
||||
highway_losses = []
|
||||
for highway_exit in outputs[-1]:
|
||||
highway_logits = highway_exit[0]
|
||||
if not self.training:
|
||||
highway_logits_all.append(highway_logits)
|
||||
highway_entropy.append(highway_exit[2])
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
highway_losses.append(highway_loss)
|
||||
|
||||
if train_highway:
|
||||
outputs = (sum(highway_losses[:-1]),) + outputs
|
||||
# exclude the final highway, of course
|
||||
else:
|
||||
outputs = (loss,) + outputs
|
||||
if not self.training:
|
||||
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
|
||||
if output_layer >= 0:
|
||||
outputs = (
|
||||
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
|
||||
) # use the highway of the last layer
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions), entropy
|
97
examples/deebert/test_glue_deebert.py
Normal file
97
examples/deebert/test_glue_deebert.py
Normal file
@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_deebert
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_setup_file():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f")
|
||||
args = parser.parse_args()
|
||||
return args.f
|
||||
|
||||
|
||||
class DeeBertTests(unittest.TestCase):
|
||||
def test_glue_deebert(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
train_args = """
|
||||
run_glue_deebert.py
|
||||
--model_type roberta
|
||||
--model_name_or_path roberta-base
|
||||
--task_name MRPC
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--max_seq_length 128
|
||||
--per_gpu_eval_batch_size=1
|
||||
--per_gpu_train_batch_size=8
|
||||
--learning_rate 2e-4
|
||||
--num_train_epochs 3
|
||||
--overwrite_output_dir
|
||||
--seed 42
|
||||
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--save_steps 0
|
||||
--overwrite_cache
|
||||
--eval_after_first_stage
|
||||
""".split()
|
||||
|
||||
eval_args = """
|
||||
run_glue_deebert.py
|
||||
--model_type roberta
|
||||
--model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--task_name MRPC
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--max_seq_length 128
|
||||
--eval_each_highway
|
||||
--eval_highway
|
||||
--overwrite_cache
|
||||
--per_gpu_eval_batch_size=1
|
||||
""".split()
|
||||
|
||||
entropy_eval_args = """
|
||||
run_glue_deebert.py
|
||||
--model_type roberta
|
||||
--model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--task_name MRPC
|
||||
--do_eval
|
||||
--do_lower_case
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
|
||||
--plot_data_dir ./examples/deebert/results/
|
||||
--max_seq_length 128
|
||||
--early_exit_entropy 0.1
|
||||
--eval_highway
|
||||
--overwrite_cache
|
||||
--per_gpu_eval_batch_size=1
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", train_args):
|
||||
result = run_glue_deebert.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.75)
|
||||
|
||||
with patch.object(sys, "argv", eval_args):
|
||||
result = run_glue_deebert.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.75)
|
||||
|
||||
with patch.object(sys, "argv", entropy_eval_args):
|
||||
result = run_glue_deebert.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.75)
|
38
examples/deebert/train_deebert.sh
Executable file
38
examples/deebert/train_deebert.sh
Executable file
@ -0,0 +1,38 @@
|
||||
#!/bin/bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
PATH_TO_DATA=/h/xinji/projects/GLUE
|
||||
|
||||
MODEL_TYPE=bert # bert or roberta
|
||||
MODEL_SIZE=base # base or large
|
||||
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
|
||||
|
||||
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
|
||||
EPOCHS=10
|
||||
if [ $MODEL_TYPE = 'bert' ]
|
||||
then
|
||||
EPOCHS=3
|
||||
MODEL_NAME=${MODEL_NAME}-uncased
|
||||
fi
|
||||
|
||||
|
||||
python -u run_glue_deebert.py \
|
||||
--model_type $MODEL_TYPE \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--task_name $DATASET \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $PATH_TO_DATA/$DATASET \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_eval_batch_size=1 \
|
||||
--per_gpu_train_batch_size=8 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs $EPOCHS \
|
||||
--overwrite_output_dir \
|
||||
--seed 42 \
|
||||
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
|
||||
--plot_data_dir ./results/ \
|
||||
--save_steps 0 \
|
||||
--overwrite_cache \
|
||||
--eval_after_first_stage
|
Loading…
Reference in New Issue
Block a user