mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Beam search type (#24288)
* test check in * adding in type hint fix on beam search * fixed code quality issue
This commit is contained in:
parent
1a113fcf65
commit
e45bc14350
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -211,7 +211,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
|
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps)
|
||||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||||
|
Loading…
Reference in New Issue
Block a user