Beam search type (#24288)

* test check in

* adding in type hint fix on beam search

* fixed code quality issue
This commit is contained in:
jprivera44 2023-06-15 08:48:02 -07:00 committed by GitHub
parent 1a113fcf65
commit e45bc14350
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@ -211,7 +211,7 @@ class BeamSearchScorer(BeamScorer):
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
) -> Dict[str, torch.Tensor]:
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):