Mirror of https://github.com/huggingface/transformers.git
bug fix

commit c99eb5a8dc (parent 3ca6a2a8ca)
@@ -282,16 +282,20 @@ class BeamSearchScorer(BeamScorer):
         eos_token_id: Optional[int] = None,
     ) -> torch.LongTensor:
         batch_size = len(self._beam_hyps)
+        final_beam_scores = final_beam_scores.view((batch_size, self.num_beams))
 
         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
                 continue
 
+            batch_beam_scores = final_beam_scores[batch_idx, :]
+            _, beam_ids = torch.sort(batch_beam_scores, descending=True)
+
             # need to add best num_beams hypotheses to generated hyps
-            for beam_id in range(self.num_beams):
-                batch_beam_idx = batch_idx * self.num_beams + beam_id
-                final_score = final_beam_scores[batch_beam_idx].item()
+            for beam_id in beam_ids:
+                batch_beam_idx = batch_idx * self.num_beams + beam_id.item()
+                final_score = batch_beam_scores[beam_id.item()].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_hyp.add(final_tokens, final_score)
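
For readers skimming the hunk: the change reshapes the flat final_beam_scores tensor to (batch_size, num_beams) and then walks each batch item's beams in descending score order instead of raw beam order. Below is a minimal, self-contained sketch of that iteration pattern; the tensor sizes and random inputs are illustrative assumptions, not values from the repository.

import torch

# Hypothetical sizes chosen only for this sketch.
batch_size, num_beams, seq_len = 2, 3, 5

# final_beam_scores arrives flattened as (batch_size * num_beams,);
# input_ids holds one decoded token sequence per beam.
final_beam_scores = torch.randn(batch_size * num_beams)
input_ids = torch.randint(0, 100, (batch_size * num_beams, seq_len))

# Reshape so each row holds one batch item's beam scores,
# mirroring the added .view((batch_size, self.num_beams)) line.
final_beam_scores = final_beam_scores.view((batch_size, num_beams))

for batch_idx in range(batch_size):
    batch_beam_scores = final_beam_scores[batch_idx, :]
    # torch.sort with descending=True yields beam indices ordered
    # best-to-worst, so hypotheses are visited highest-score first.
    _, beam_ids = torch.sort(batch_beam_scores, descending=True)
    for beam_id in beam_ids:
        batch_beam_idx = batch_idx * num_beams + beam_id.item()
        final_score = batch_beam_scores[beam_id.item()].item()
        final_tokens = input_ids[batch_beam_idx]
        # Stand-in for beam_hyp.add(final_tokens, final_score).
        print(batch_idx, beam_id.item(), round(final_score, 3), final_tokens.tolist())

Visiting beams best-first means beam_hyp.add receives the strongest finished candidates immediately; how much that affects the final output depends on the pruning logic inside the hypotheses container, which this hunk does not show.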