mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[s2s] run_eval supports --prefix clarg. (#6953)
This commit is contained in:
parent
563ffb3dc3
commit
b76cb1c3df
@ -36,6 +36,7 @@ def generate_summaries_or_translations(
|
|||||||
device: str = DEFAULT_DEVICE,
|
device: str = DEFAULT_DEVICE,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
task="summarization",
|
task="summarization",
|
||||||
|
prefix=None,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Save model.generate results to <out_file>, and return how long it took."""
|
"""Save model.generate results to <out_file>, and return how long it took."""
|
||||||
@ -51,9 +52,10 @@ def generate_summaries_or_translations(
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
# update config with task specific params
|
# update config with task specific params
|
||||||
use_task_specific_params(model, task)
|
use_task_specific_params(model, task)
|
||||||
|
if prefix is None:
|
||||||
|
prefix = prefix or getattr(model.config, "prefix", "") or ""
|
||||||
for examples_chunk in tqdm(list(chunks(examples, batch_size))):
|
for examples_chunk in tqdm(list(chunks(examples, batch_size))):
|
||||||
if "t5" in model_name:
|
examples_chunk = [prefix + text for text in examples_chunk]
|
||||||
examples_chunk = [model.config.prefix + text for text in examples_chunk]
|
|
||||||
batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
|
batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
|
||||||
summaries = model.generate(
|
summaries = model.generate(
|
||||||
input_ids=batch.input_ids,
|
input_ids=batch.input_ids,
|
||||||
@ -78,6 +80,9 @@ def run_generate():
|
|||||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
||||||
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
|
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
|
||||||
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
|
||||||
|
)
|
||||||
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
||||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -103,6 +108,7 @@ def run_generate():
|
|||||||
device=args.device,
|
device=args.device,
|
||||||
fp16=args.fp16,
|
fp16=args.fp16,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
|
prefix=args.prefix,
|
||||||
**parsed,
|
**parsed,
|
||||||
)
|
)
|
||||||
if args.reference_path is None:
|
if args.reference_path is None:
|
||||||
|
@ -160,7 +160,7 @@ def test_opus_mt_distill_script():
|
|||||||
metrics = load_json(model.metrics_save_path)
|
metrics = load_json(model.metrics_save_path)
|
||||||
first_step_stats = metrics["val"][0]
|
first_step_stats = metrics["val"][0]
|
||||||
last_step_stats = metrics["val"][-1]
|
last_step_stats = metrics["val"][-1]
|
||||||
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
|
assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check
|
||||||
|
|
||||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user