seq2seq/run_eval.py can take decoder_start_token_id (#5949)
parent 5b193b39b0
commit 9dab39feea
@@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule):
         self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+            self.model.config.decoder_start_token_id = self.decoder_start_token_id
         if isinstance(self.tokenizer, MBartTokenizer):
             self.dataset_class = MBartDataset
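Setting `self.model.config.decoder_start_token_id` here means later `generate()` calls pick up the mBART start token without extra plumbing. A minimal standalone sketch of the same resolution logic, with the checkpoint name and target-language code chosen purely for illustration (neither appears in this diff):

    from transformers import AutoModelForSeq2SeqLM, MBartTokenizer

    # Illustrative checkpoint and target language; the finetuning module reads these from hparams.
    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro")
    tgt_lang = "ro_RO"

    if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
        # mBART starts decoding from the target-language code token.
        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[tgt_lang]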
@@ -30,6 +30,7 @@ def generate_summaries_or_translations(
     device: str = DEFAULT_DEVICE,
     fp16=False,
     task="summarization",
+    decoder_start_token_id=None,
     **gen_kwargs,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")
@@ -37,6 +38,8 @@ def generate_summaries_or_translations(
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
+    if decoder_start_token_id is None:
+        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
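The `gen_kwargs.pop(...)` guard above handles callers that still pass `decoder_start_token_id` inside `**gen_kwargs`: the value is pulled out so that the explicit keyword in the `model.generate(...)` call further down cannot collide with a duplicate key in `gen_kwargs`. A small sketch of that behaviour, with made-up values:

    # Hypothetical caller values, just to show the effect of the pop.
    gen_kwargs = {"num_beams": 4, "decoder_start_token_id": 250020}
    decoder_start_token_id = None

    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

    # The id now travels only through the named argument; leaving it in gen_kwargs
    # while also passing it explicitly to model.generate() would raise a TypeError
    # ("got multiple values for keyword argument").
    assert decoder_start_token_id == 250020
    assert "decoder_start_token_id" not in gen_kwargs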
@@ -48,7 +51,12 @@ def generate_summaries_or_translations(
             batch = [model.config.prefix + text for text in batch]
         batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
         input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
-        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
+        summaries = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_start_token_id=decoder_start_token_id,
+            **gen_kwargs,
+        )
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         for hypothesis in dec:
             fout.write(hypothesis + "\n")
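Passing `decoder_start_token_id` straight through to `model.generate(...)` is the core of the change; when it is `None`, generation falls back to the model config, so existing behaviour is unchanged. A rough end-to-end sketch of the new call shape (checkpoint, input text and beam setting are illustrative only):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Illustrative checkpoint and input; run_eval.py gets these from the command line.
    model_name = "facebook/mbart-large-en-ro"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    batch = tokenizer(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")
    generated = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"],  # explicit start token (mBART language code)
        num_beams=4,
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))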
@@ -66,6 +74,13 @@ def run_generate():
     parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
     parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
+    parser.add_argument(
+        "--decoder_start_token_id",
+        type=int,
+        default=None,
+        required=False,
+        help="decoder_start_token_id (otherwise will look at config)",
+    )
     parser.add_argument(
         "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
     )
@@ -83,6 +98,7 @@ def run_generate():
         device=args.device,
         fp16=args.fp16,
         task=args.task,
+        decoder_start_token_id=args.decoder_start_token_id,
     )
     if args.reference_path is None:
         return
@@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
 
         return encoded_inputs
 
-    def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
-        return [self.decode(seq, **kwargs) for seq in sequences]
+    def batch_decode(
+        self, sequences: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+    ) -> List[str]:
+        """
+        Convert a list of lists of token ids into a list of strings by calling decode.
+
+        Args:
+            token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+            skip_special_tokens: if set to True, will replace special tokens.
+            clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+        """
+        return [
+            self.decode(
+                seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces
+            )
+            for seq in sequences
+        ]
 
     def decode(
         self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
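The reworked `batch_decode` simply forwards the two flags to `decode` for every sequence, which is what `run_eval.py` above relies on when it calls `tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)`. A small usage sketch, with the tokenizer checkpoint chosen only for illustration:

    from transformers import AutoTokenizer

    # Illustrative tokenizer; any pretrained tokenizer exposes batch_decode.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    ids = tokenizer(["hello world", "a second sentence"])["input_ids"]

    # Both flags are forwarded to decode() for every sequence in the batch.
    texts = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(texts)  # e.g. ['hello world', 'a second sentence']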