🚨🚨 Fix group beam search (#24407)

* group_beam_search now works correctly

* add argument descriptions

* add a comment

* format

* make style

* change comment

* Update src/transformers/generation/beam_search.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

---------

Co-authored-by: shogo.fujita <shogo.fujita@legalontech.jp>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
hukuda222 2023-06-27 18:43:10 +09:00 committed by GitHub
parent 68c92981ff
commit 43479ef98f
2 changed files with 33 additions and 17 deletions

src/transformers/generation/beam_search.py

@@ -43,6 +43,10 @@ PROCESS_INPUTS_DOCSTRING = r"""
             The id of the *padding* token.
         eos_token_id (`Union[int, List[int]]`, *optional*):
             The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        beam_indices (`torch.LongTensor`, *optional*):
+            Beam indices indicating to which beam hypothesis each token corresponds.
+        group_index (`int`, *optional*):
+            The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].

     Return:
         `UserDict`: A dictionary composed of the fields as defined above:
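To see where these new arguments surface, here is a minimal sketch (values are illustrative, not part of the patch) of constructing a `BeamSearchScorer` for grouped decoding; after this change the scorer keeps one `BeamHypotheses` per (batch, group) pair:

    import torch
    from transformers import BeamSearchScorer

    # Illustrative sizes: 2 sequences, 4 beams split into 2 groups of 2.
    scorer = BeamSearchScorer(
        batch_size=2,
        num_beams=4,
        device=torch.device("cpu"),
        num_beam_groups=2,
    )
    # One BeamHypotheses per (batch, group) pair after this fix.
    assert len(scorer._beam_hyps) == 2 * 2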
@@ -175,16 +179,22 @@ class BeamSearchScorer(BeamScorer):
         self.group_size = self.num_beams // self.num_beam_groups
         self._is_init = False
+        # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
+        # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
         self._beam_hyps = [
             BeamHypotheses(
-                num_beams=self.num_beams,
+                num_beams=self.group_size,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
                 max_length=max_length,
             )
-            for _ in range(batch_size)
+            for _ in range(batch_size * self.num_beam_groups)
         ]
-        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
+        # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
+        # in the i-th mini-batch is complete.
+        self._done = torch.tensor(
+            [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
+        )

         if not isinstance(num_beams, int) or num_beams <= 1:
             raise ValueError(
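The comment above defines the flat layout: entry `i * num_beam_groups + j` belongs to group `j` of batch element `i`. A standalone sketch of that index arithmetic (hypothetical sizes):

    # Hypothetical sizes, mirroring self._beam_hyps' flat (batch, group) layout.
    batch_size, num_beam_groups = 2, 3

    beam_hyps = [
        f"hyps(batch={i}, group={j})"
        for i in range(batch_size)
        for j in range(num_beam_groups)
    ]

    def flat_index(batch_idx, group_index):
        # Same arithmetic as self._beam_hyps[batch_idx * self.num_beam_groups + group_index]
        return batch_idx * num_beam_groups + group_index

    assert len(beam_hyps) == batch_size * num_beam_groups
    assert beam_hyps[flat_index(1, 2)] == "hyps(batch=1, group=2)"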
@@ -211,9 +221,11 @@ class BeamSearchScorer(BeamScorer):
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
+        group_index: Optional[int] = 0,
     ) -> Dict[str, torch.Tensor]:
         cur_len = input_ids.shape[-1] + 1  # add up to the length which the next_scores is calculated on
-        batch_size = len(self._beam_hyps)
+        batch_size = len(self._beam_hyps) // self.num_beam_groups
         if not (batch_size == (input_ids.shape[0] // self.group_size)):
             if self.num_beam_groups > 1:
                 raise ValueError(
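Because `self._beam_hyps` now has `batch_size * num_beam_groups` entries, `process` recovers the real batch size by integer division, and `input_ids` arrives with one group's worth of beams per batch element. A quick check of that arithmetic (example numbers only):

    # Example numbers: 6 beams split into 3 groups of 2.
    num_beams, num_beam_groups, batch_size = 6, 3, 2
    group_size = num_beams // num_beam_groups        # 2 beams per group

    len_beam_hyps = batch_size * num_beam_groups     # 6 BeamHypotheses in total
    assert len_beam_hyps // num_beam_groups == batch_size

    # process() is called once per group, on batch_size * group_size rows of input_ids.
    assert batch_size * group_size == 4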
@@ -234,9 +246,10 @@ class BeamSearchScorer(BeamScorer):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]

-        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-            if self._done[batch_idx]:
-                if self.num_beams < len(beam_hyp):
+        for batch_idx in range(batch_size):
+            batch_group_idx = batch_idx * self.num_beam_groups + group_index
+            if self._done[batch_group_idx]:
+                if self.num_beams < len(self._beam_hyps[batch_group_idx]):
                     raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
                 if eos_token_id is None or pad_token_id is None:
                     raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
@@ -264,7 +277,7 @@ class BeamSearchScorer(BeamScorer):
                     else:
                         beam_index = None

-                    beam_hyp.add(
+                    self._beam_hyps[batch_group_idx].add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
@@ -287,7 +300,7 @@ class BeamSearchScorer(BeamScorer):
                 )

             # Check if we are done so that we can save a pad step if all(done)
-            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+            self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
                 next_scores[batch_idx].max().item(), cur_len
             )
@@ -310,20 +323,20 @@ class BeamSearchScorer(BeamScorer):
         eos_token_id: Optional[Union[int, List[int]]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.LongTensor]:
-        batch_size = len(self._beam_hyps)
+        batch_size = len(self._beam_hyps) // self.num_beam_groups
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]

         # finalize all open beam hypotheses and add to generated hypotheses
-        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-            if self._done[batch_idx]:
+        for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_group_idx]:
                 continue

             # all open beam hypotheses are added to the beam hypothesis
             # beam hypothesis class automatically keeps the best beams
-            for beam_id in range(self.num_beams):
-                batch_beam_idx = batch_idx * self.num_beams + beam_id
+            for index_per_group in range(self.group_size):
+                batch_beam_idx = batch_group_idx * self.group_size + index_per_group
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
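Since `input_ids` in `finalize` holds `batch_size * num_beams` rows laid out group-major within each batch element, stepping through `batch_group_idx * group_size + index_per_group` visits every row exactly once. A sketch verifying the enumeration (hypothetical sizes):

    # Hypothetical sizes: 2 batch elements, 3 groups of 2 beams each.
    batch_size, num_beam_groups, group_size = 2, 3, 2
    num_beams = num_beam_groups * group_size

    visited = [
        batch_group_idx * group_size + index_per_group
        for batch_group_idx in range(batch_size * num_beam_groups)
        for index_per_group in range(group_size)
    ]
    # Each of the batch_size * num_beams input rows is finalized exactly once.
    assert visited == list(range(batch_size * num_beams))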
@@ -336,8 +349,10 @@ class BeamSearchScorer(BeamScorer):
         best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

         # retrieve best hypotheses
-        for i, beam_hyp in enumerate(self._beam_hyps):
-            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
+        for i in range(batch_size):
+            beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
+            candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
+            sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])

             for j in range(self.num_beam_hyps_to_keep):
                 best_hyp_tuple = sorted_hyps.pop()
                 best_score = best_hyp_tuple[0]
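The retrieval step now pools the finished beams of all groups belonging to one batch element and sorts them jointly, so the kept hypotheses are the best across groups rather than the best of a single group. A toy illustration with mock (score, tokens) tuples (values are made up):

    # Mock BeamHypotheses.beams lists for the two groups of one batch element.
    beam_hyps_in_batch = [
        [(-1.2, "group0-hyp-a"), (-0.4, "group0-hyp-b")],
        [(-0.9, "group1-hyp-a"), (-0.1, "group1-hyp-b")],
    ]
    num_beam_hyps_to_keep = 2

    # Pool all groups, then sort ascending by score so pop() yields the best first.
    candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp]
    sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])

    best = [sorted_hyps.pop() for _ in range(num_beam_hyps_to_keep)]
    # Highest-scoring hypotheses win regardless of which group produced them.
    assert [tokens for _, tokens in best] == ["group1-hyp-b", "group0-hyp-b"]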

src/transformers/generation/utils.py

@@ -3522,10 +3522,10 @@ class GenerationMixin:
             else self.generation_config.return_dict_in_generate
         )

-        batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
         num_beam_groups = beam_scorer.num_beam_groups
         num_sub_beams = num_beams // num_beam_groups
+        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
         device = input_ids.device

         batch_beam_size, cur_len = input_ids.shape
@@ -3648,6 +3648,7 @@ class GenerationMixin:
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 beam_indices=process_beam_indices,
+                group_index=beam_group_idx,
             )
             beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
             beam_next_tokens = beam_outputs["next_beam_tokens"]