Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-30 17:52:35 +06:00
add past re-ordering for beam search
This commit is contained in:
parent 6bca56fdb0
commit 90cda45e9e
@@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module):
             beam_words = input_ids.new([x[1] for x in next_batch_beam])
             beam_idx = input_ids.new([x[2] for x in next_batch_beam])
 
-            # re-order batch and internal states
+            # re-order batch
             input_ids = input_ids[beam_idx, :]
             input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
-            # TODO: Activate cache
-            # for k in cache.keys():
-            #     if k != 'slen':
-            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
+
+            # re-order internal states
+            if past:
+                reordered_past = []
+                for layer_past in past:
+                    # copy the relevant beam idx past to past
+                    reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
+                    reordered_past.append(torch.cat(reordered_layer_past, dim=1))
+                past = tuple(reordered_past)
 
             # update current length
             cur_len = cur_len + 1
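For context, a minimal, self-contained sketch of what the added `past` re-ordering does during beam search: after each step, every cache row must be swapped to follow the beam hypothesis it now continues. The tensor shapes (GPT-2-style key/value caches stacked on dim 0, with the batch*beam dimension at position 1), the dummy sizes, and the example `beam_idx` are assumptions for illustration, not part of the commit.

import torch

batch_size, num_beams, num_heads, seq_len, head_dim, num_layers = 1, 3, 2, 4, 8, 2

# fake cache: one (2, batch*beams, heads, seq, head_dim) tensor per layer,
# where dim 0 stacks the key and value tensors
past = tuple(
    torch.randn(2, batch_size * num_beams, num_heads, seq_len, head_dim)
    for _ in range(num_layers)
)

# beam_idx says which existing beam each surviving hypothesis continues from,
# e.g. beam 0 survives twice and beam 2 once
beam_idx = torch.tensor([0, 0, 2])

# same re-ordering as in the commit: pick beam i along the batch dim (dim 1),
# keep it as a singleton batch, then concatenate the picks back together
reordered_past = []
for layer_past in past:
    reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
    reordered_past.append(torch.cat(reordered_layer_past, dim=1))
past = tuple(reordered_past)

# the cache rows now line up with the re-ordered input_ids rows
assert past[0].shape == (2, batch_size * num_beams, num_heads, seq_len, head_dim)
assert torch.equal(past[0][:, 1], past[0][:, 0])  # rows 0 and 1 both continue beam 0

The same index-select-and-concatenate pattern works for any model whose cache keeps the beam dimension at position 1; models with a different cache layout would need the indexing adjusted accordingly.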