Beam search type (#24288)

* test check in

* adding in type hint fix on beam search

* fixed code quality issue
This commit is contained in:
jprivera44 2023-06-15 08:48:02 -07:00 committed by GitHub
parent 1a113fcf65
commit e45bc14350
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -15,7 +15,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict from collections import UserDict
from typing import List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -211,7 +211,7 @@ class BeamSearchScorer(BeamScorer):
pad_token_id: Optional[int] = None, pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None, eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None, beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]: ) -> Dict[str, torch.Tensor]:
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps) batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)): if not (batch_size == (input_ids.shape[0] // self.group_size)):