mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
🚨🚨 Fix group beam search (#24407)
* group_beam_search now works correctly * add argument descriptions * add a comment * format * make style * change comment * Update src/transformers/generation/beam_search.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> --------- Co-authored-by: shogo.fujita <shogo.fujita@legalontech.jp> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent
68c92981ff
commit
43479ef98f
@ -43,6 +43,10 @@ PROCESS_INPUTS_DOCSTRING = r"""
|
|||||||
The id of the *padding* token.
|
The id of the *padding* token.
|
||||||
eos_token_id (`Union[int, List[int]]`, *optional*):
|
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
|
beam_indices (`torch.LongTensor]`, *optional*):
|
||||||
|
Beam indices indicating to which beam hypothesis each token correspond.
|
||||||
|
group_index (`int`, *optional*):
|
||||||
|
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
`UserDict`: A dictionary composed of the fields as defined above:
|
`UserDict`: A dictionary composed of the fields as defined above:
|
||||||
@ -175,16 +179,22 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
self.group_size = self.num_beams // self.num_beam_groups
|
self.group_size = self.num_beams // self.num_beam_groups
|
||||||
|
|
||||||
self._is_init = False
|
self._is_init = False
|
||||||
|
# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
|
||||||
|
# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
|
||||||
self._beam_hyps = [
|
self._beam_hyps = [
|
||||||
BeamHypotheses(
|
BeamHypotheses(
|
||||||
num_beams=self.num_beams,
|
num_beams=self.group_size,
|
||||||
length_penalty=self.length_penalty,
|
length_penalty=self.length_penalty,
|
||||||
early_stopping=self.do_early_stopping,
|
early_stopping=self.do_early_stopping,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
for _ in range(batch_size)
|
for _ in range(batch_size * self.num_beam_groups)
|
||||||
]
|
]
|
||||||
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
|
# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
|
||||||
|
# in the i-th mini-batch is complete.
|
||||||
|
self._done = torch.tensor(
|
||||||
|
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(num_beams, int) or num_beams <= 1:
|
if not isinstance(num_beams, int) or num_beams <= 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -211,9 +221,11 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
|
group_index: Optional[int] = 0,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
|
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||||
|
|
||||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||||
if self.num_beam_groups > 1:
|
if self.num_beam_groups > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -234,9 +246,10 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
|
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_idx in range(batch_size):
|
||||||
if self._done[batch_idx]:
|
batch_group_idx = batch_idx * self.num_beam_groups + group_index
|
||||||
if self.num_beams < len(beam_hyp):
|
if self._done[batch_group_idx]:
|
||||||
|
if self.num_beams < len(self._beam_hyps[batch_group_idx]):
|
||||||
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
||||||
if eos_token_id is None or pad_token_id is None:
|
if eos_token_id is None or pad_token_id is None:
|
||||||
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
||||||
@ -264,7 +277,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
else:
|
else:
|
||||||
beam_index = None
|
beam_index = None
|
||||||
|
|
||||||
beam_hyp.add(
|
self._beam_hyps[batch_group_idx].add(
|
||||||
input_ids[batch_beam_idx].clone(),
|
input_ids[batch_beam_idx].clone(),
|
||||||
next_score.item(),
|
next_score.item(),
|
||||||
beam_indices=beam_index,
|
beam_indices=beam_index,
|
||||||
@ -287,7 +300,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if we are done so that we can save a pad step if all(done)
|
# Check if we are done so that we can save a pad step if all(done)
|
||||||
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
|
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
||||||
next_scores[batch_idx].max().item(), cur_len
|
next_scores[batch_idx].max().item(), cur_len
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -310,20 +323,20 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.LongTensor]:
|
) -> Tuple[torch.LongTensor]:
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
|
|
||||||
# finalize all open beam hypotheses and add to generated hypotheses
|
# finalize all open beam hypotheses and add to generated hypotheses
|
||||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||||
if self._done[batch_idx]:
|
if self._done[batch_group_idx]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# all open beam hypotheses are added to the beam hypothesis
|
# all open beam hypotheses are added to the beam hypothesis
|
||||||
# beam hypothesis class automatically keeps the best beams
|
# beam hypothesis class automatically keeps the best beams
|
||||||
for beam_id in range(self.num_beams):
|
for index_per_group in range(self.group_size):
|
||||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
|
||||||
final_score = final_beam_scores[batch_beam_idx].item()
|
final_score = final_beam_scores[batch_beam_idx].item()
|
||||||
final_tokens = input_ids[batch_beam_idx]
|
final_tokens = input_ids[batch_beam_idx]
|
||||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||||
@ -336,8 +349,10 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
||||||
|
|
||||||
# retrieve best hypotheses
|
# retrieve best hypotheses
|
||||||
for i, beam_hyp in enumerate(self._beam_hyps):
|
for i in range(batch_size):
|
||||||
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
|
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
|
||||||
|
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
|
||||||
|
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
|
||||||
for j in range(self.num_beam_hyps_to_keep):
|
for j in range(self.num_beam_hyps_to_keep):
|
||||||
best_hyp_tuple = sorted_hyps.pop()
|
best_hyp_tuple = sorted_hyps.pop()
|
||||||
best_score = best_hyp_tuple[0]
|
best_score = best_hyp_tuple[0]
|
||||||
|
@ -3522,10 +3522,10 @@ class GenerationMixin:
|
|||||||
else self.generation_config.return_dict_in_generate
|
else self.generation_config.return_dict_in_generate
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
|
||||||
num_beams = beam_scorer.num_beams
|
num_beams = beam_scorer.num_beams
|
||||||
num_beam_groups = beam_scorer.num_beam_groups
|
num_beam_groups = beam_scorer.num_beam_groups
|
||||||
num_sub_beams = num_beams // num_beam_groups
|
num_sub_beams = num_beams // num_beam_groups
|
||||||
|
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
|
||||||
device = input_ids.device
|
device = input_ids.device
|
||||||
|
|
||||||
batch_beam_size, cur_len = input_ids.shape
|
batch_beam_size, cur_len = input_ids.shape
|
||||||
@ -3648,6 +3648,7 @@ class GenerationMixin:
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
beam_indices=process_beam_indices,
|
beam_indices=process_beam_indices,
|
||||||
|
group_index=beam_group_idx,
|
||||||
)
|
)
|
||||||
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||||
|
Loading…
Reference in New Issue
Block a user