[s2s] distributed eval cleanup (#7186)

This commit is contained in:
Sam Shleifer 2020-09-16 15:38:37 -04:00 committed by GitHub
parent 3babef815c
commit 0203ad43bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 31 deletions

View File

@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
--fp16 \
--bs 32
```
### Multi-GPU Evaluation
Here is a command to run xsum evaluation on 8 GPUs. It is more than linearly faster than run_eval.py in some cases,
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all command-line arguments.
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
--model_name sshleifer/distilbart-large-xsum-12-3 \
--save_dir xsum_generations \
--data_dir xsum \
--fp16 # you can pass generate kwargs like num_beams here, just like run_eval.py
```
Contributions that implement this command for other distributed hardware setups are welcome!
#### run_eval tips and tricks

View File

@ -4,7 +4,7 @@ import time
from json import JSONDecodeError
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
from torch.utils.data import DataLoader
@ -22,7 +22,7 @@ try:
calculate_rouge,
lmap,
load_json,
parse_numeric_cl_kwargs,
parse_numeric_n_bool_cl_kwargs,
save_json,
use_task_specific_params,
write_txt_file,
@ -34,7 +34,7 @@ except ImportError:
calculate_rouge,
lmap,
load_json,
parse_numeric_cl_kwargs,
parse_numeric_n_bool_cl_kwargs,
save_json,
use_task_specific_params,
write_txt_file,
@ -50,7 +50,6 @@ def eval_data_dir(
type_path="val",
n_obs=None,
fp16=False,
num_beams: int = 4,
task="summarization",
local_rank=None,
**generate_kwargs,
@ -81,23 +80,21 @@ def eval_data_dir(
n_obs=n_obs,
prefix=model.config.prefix,
)
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
# I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
results = []
for batch in tqdm(data_loader):
summaries = model.generate(
input_ids=batch["input_ids"].to(model.device),
attention_mask=batch["attention_mask"].to(model.device),
num_beams=num_beams,
**generate_kwargs,
)
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
ids = batch["ids"]
for i in range(len(labels)):
label, pred = labels[i], preds[i]
results.append(dict(pred=pred, label=label, id=ids[i].item()))
for i, pred in enumerate(preds):
results.append(dict(pred=pred, id=ids[i].item()))
save_json(results, save_path)
return results, sampler.num_replicas
@ -139,8 +136,8 @@ def run_generate():
parser.add_argument("--debug", action="store_true")
start_time = time.time()
args, rest = parser.parse_known_args()
generate_kwargs = parse_numeric_cl_kwargs(rest)
if generate_kwargs:
generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
if generate_kwargs and args.local_rank <= 0:
print(f"parsed the following generate kwargs: {generate_kwargs}")
json_save_dir = Path(args.save_dir + "_tmp")
Path(json_save_dir).mkdir(exist_ok=True) # this handles locking.
@ -168,7 +165,10 @@ def run_generate():
save_dir = Path(args.save_dir)
save_dir.mkdir(exist_ok=True)
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
preds, labels = combine_partial_results(partial_results)
preds = combine_partial_results(partial_results)
tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
# Calculate metrics, save metrics, and save _generations.txt
calc_bleu = "translation" in args.task
score_fn = calculate_bleu if calc_bleu else calculate_rouge
@ -179,7 +179,7 @@ def run_generate():
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
# TODO(@stas00): add whatever metadata to metrics
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
save_json(metrics, metrics_save_path)
save_json(metrics, metrics_save_path, indent=None)
print(metrics)
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
if args.debug:
@ -188,15 +188,14 @@ def run_generate():
shutil.rmtree(json_save_dir)
def combine_partial_results(partial_results) -> Tuple[List, List]:
def combine_partial_results(partial_results) -> List:
"""Concatenate partial results into one file, then sort it by id."""
records = []
for partial_result in partial_results:
records.extend(partial_result)
records = list(sorted(records, key=lambda x: x["id"]))
preds = [x["pred"] for x in records]
labels = [x["label"] for x in records]
return preds, labels
return preds
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:

View File

@ -156,7 +156,7 @@ def run_generate(verbose=True):
scores["info"] = args.info
if verbose:
print(*scores)
print(scores)
if args.score_path is not None:
path = args.score_path

View File

@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
if distributed:
return DistributedSortishSampler(self, batch_size, **kwargs)
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
else:
return SortishSampler(self.src_lens, batch_size)
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
class SortishSampler(Sampler):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def __init__(self, data, batch_size):
self.data, self.bs = data, batch_size
def __init__(self, data, batch_size, shuffle=True):
self.data, self.bs, self.shuffle = data, batch_size, shuffle
def __len__(self) -> int:
return len(self.data)
def __iter__(self):
return iter(sortish_sampler_indices(self.data, self.bs))
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
def sortish_sampler_indices(data: List, bs: int) -> np.array:
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if not shuffle:
return np.argsort(np.array(data) * -1)
def key_fn(i):
return data[i]
@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class DistributedSortishSampler(Sampler):
"""Copied from torch DistributedSampler"""
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
self.num_samples = len(self.available_indices)
self.batch_size = batch_size
self.add_extra_examples = add_extra_examples
self.shuffle = shuffle
def __iter__(self) -> Iterable:
g = torch.Generator()
g.manual_seed(self.epoch)
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
indices = [self.available_indices[i] for i in sortish_indices]
assert len(indices) == self.num_samples
return iter(indices)
@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path):
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=4)
json.dump(content, f, indent=indent, **json_dump_kwargs)
def load_json(path):