[s2s] run_eval supports --prefix clarg. (#6953)

This commit is contained in:
Sam Shleifer 2020-09-12 01:08:21 -04:00 committed by GitHub
parent 563ffb3dc3
commit b76cb1c3df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 3 deletions

View File

@ -36,6 +36,7 @@ def generate_summaries_or_translations(
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
prefix=None,
**generate_kwargs,
) -> Dict:
"""Save model.generate results to <out_file>, and return how long it took."""
@ -51,9 +52,10 @@ def generate_summaries_or_translations(
start_time = time.time()
# update config with task specific params
use_task_specific_params(model, task)
if prefix is None:
prefix = prefix or getattr(model.config, "prefix", "") or ""
for examples_chunk in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
examples_chunk = [model.config.prefix + text for text in examples_chunk]
examples_chunk = [prefix + text for text in examples_chunk]
batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
summaries = model.generate(
input_ids=batch.input_ids,
@ -78,6 +80,9 @@ def run_generate():
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument(
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
)
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
@ -103,6 +108,7 @@ def run_generate():
device=args.device,
fp16=args.fp16,
task=args.task,
prefix=args.prefix,
**parsed,
)
if args.reference_path is None:

View File

@ -160,7 +160,7 @@ def test_opus_mt_distill_script():
metrics = load_json(model.metrics_save_path)
first_step_stats = metrics["val"][0]
last_step_stats = metrics["val"][-1]
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check
assert last_step_stats["val_avg_gen_time"] >= 0.01