This commit is contained in:
Ayush Jain 2020-12-07 19:23:39 +05:30
parent 3ca6a2a8ca
commit c99eb5a8dc

View File

@ -282,16 +282,20 @@ class BeamSearchScorer(BeamScorer):
eos_token_id: Optional[int] = None,
) -> torch.LongTensor:
batch_size = len(self._beam_hyps)
final_beam_scores = final_beam_scores.view((batch_size, self.num_beams))
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
continue
batch_beam_scores = final_beam_scores[batch_idx, :]
_, beam_ids = torch.sort(batch_beam_scores, descending=True)
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
for beam_id in beam_ids:
batch_beam_idx = batch_idx * self.num_beams + beam_id.item()
final_score = batch_beam_scores[beam_id.item()].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)