[generation] consistently add eos tokens (#6982)

Currently beam search returns inconsistent outputs - if hypos have different lengths we get eos, if they are the same - we don't.

This PR makes the output consistent.

Also why not also replace:

```
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
```
with:
```
            decoded[i, sent_lengths[i]] = eos_token_id
```
Shouldn't eos always be there? If the data gets truncated, the caller needs to user a larger `max_length`.

Please correct me if my logic is flawed.
This commit is contained in:
Stas Bekman 2020-09-09 01:08:36 -07:00 committed by GitHub
parent d0963486c1
commit 03e363f9ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -841,21 +841,19 @@ class GenerationMixin:
sent_lengths[effective_batch_idx] = len(best_hyp)
best.append(best_hyp)
# shorter batches are padded
# prepare for adding eos
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
decoded = input_ids.new(output_batch_size, sent_max_len)
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
else:
# none of the hypotheses have an eos_token
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
return decoded