add past re-ordering for beam search

This commit is contained in:
patrickvonplaten 2019-12-25 16:29:20 +01:00
parent 6bca56fdb0
commit 90cda45e9e

View File

@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module):
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and internal states
# re-order batch
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
# TODO: Activate cache
# for k in cache.keys():
# if k != 'slen':
# cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
# re-order internal states
if past:
reordered_past = []
for layer_past in past:
# copy the relevant beam idx past to past
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_past.append(torch.cat(reordered_layer_past, dim=1))
past = tuple(reordered_past)
# update current length
cur_len = cur_len + 1