mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[seq2seq] distillation.py accepts trainer arguments (#5865)
This commit is contained in:
parent
ba2400189b
commit
dad5e12e54
@ -149,7 +149,7 @@ If 'translation' is in your task name, the computed metric will be BLEU. Otherwi
|
|||||||
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
|
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
|
||||||
```bash
|
```bash
|
||||||
export DATA_DIR=wmt_en_ro
|
export DATA_DIR=wmt_en_ro
|
||||||
python run_eval.py t5_base \
|
python run_eval.py t5-base \
|
||||||
$DATA_DIR/val.source t5_val_generations.txt \
|
$DATA_DIR/val.source t5_val_generations.txt \
|
||||||
--reference_path $DATA_DIR/val.target \
|
--reference_path $DATA_DIR/val.target \
|
||||||
--score_path enro_bleu.json \
|
--score_path enro_bleu.json \
|
||||||
|
@ -446,6 +446,7 @@ def distill_main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user