mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Beam search type (#24288)
* test check in * adding in type hint fix on beam search * fixed code quality issue
This commit is contained in:
parent
1a113fcf65
commit
e45bc14350
@ -15,7 +15,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -211,7 +211,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
|
||||
batch_size = len(self._beam_hyps)
|
||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||
|
Loading…
Reference in New Issue
Block a user