mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[generation] consistently add eos tokens (#6982)
Currently beam search returns inconsistent outputs - if hypos have different lengths we get eos, if they are the same - we don't. This PR makes the output consistent. Also why not also replace: ``` if sent_lengths[i] < max_length: decoded[i, sent_lengths[i]] = eos_token_id ``` with: ``` decoded[i, sent_lengths[i]] = eos_token_id ``` Shouldn't eos always be there? If the data gets truncated, the caller needs to user a larger `max_length`. Please correct me if my logic is flawed.
This commit is contained in:
parent
d0963486c1
commit
03e363f9ae
@ -841,21 +841,19 @@ class GenerationMixin:
|
||||
sent_lengths[effective_batch_idx] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
|
||||
# shorter batches are padded
|
||||
# prepare for adding eos
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
decoded = input_ids.new(output_batch_size, sent_max_len)
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
||||
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
else:
|
||||
# none of the hypotheses have an eos_token
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
|
||||
return decoded
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user