🚨🚨 Fix group beam search (#24407)

* group_beam_search now works correctly

* add argument descriptions

* add a comment

* format

* make style

* change comment

* Update src/transformers/generation/beam_search.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

---------

Co-authored-by: shogo.fujita <shogo.fujita@legalontech.jp>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
hukuda222 2023-06-27 18:43:10 +09:00 committed by GitHub
parent 68c92981ff
commit 43479ef98f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 17 deletions

View File

@ -43,6 +43,10 @@ PROCESS_INPUTS_DOCSTRING = r"""
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor]`, *optional*):
Beam indices indicating to which beam hypothesis each token correspond.
group_index (`int`, *optional*):
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
Return:
`UserDict`: A dictionary composed of the fields as defined above:
@ -175,16 +179,22 @@ class BeamSearchScorer(BeamScorer):
self.group_size = self.num_beams // self.num_beam_groups
self._is_init = False
# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
self._beam_hyps = [
BeamHypotheses(
num_beams=self.num_beams,
num_beams=self.group_size,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
for _ in range(batch_size)
for _ in range(batch_size * self.num_beam_groups)
]
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
# in the i-th mini-batch is complete.
self._done = torch.tensor(
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
@ -211,9 +221,11 @@ class BeamSearchScorer(BeamScorer):
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps)
batch_size = len(self._beam_hyps) // self.num_beam_groups
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
raise ValueError(
@ -234,9 +246,10 @@ class BeamSearchScorer(BeamScorer):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
for batch_idx in range(batch_size):
batch_group_idx = batch_idx * self.num_beam_groups + group_index
if self._done[batch_group_idx]:
if self.num_beams < len(self._beam_hyps[batch_group_idx]):
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
if eos_token_id is None or pad_token_id is None:
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
@ -264,7 +277,7 @@ class BeamSearchScorer(BeamScorer):
else:
beam_index = None
beam_hyp.add(
self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
@ -287,7 +300,7 @@ class BeamSearchScorer(BeamScorer):
)
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
next_scores[batch_idx].max().item(), cur_len
)
@ -310,20 +323,20 @@ class BeamSearchScorer(BeamScorer):
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)
batch_size = len(self._beam_hyps) // self.num_beam_groups
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_group_idx]:
continue
# all open beam hypotheses are added to the beam hypothesis
# beam hypothesis class automatically keeps the best beams
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
for index_per_group in range(self.group_size):
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
@ -336,8 +349,10 @@ class BeamSearchScorer(BeamScorer):
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
# retrieve best hypotheses
for i, beam_hyp in enumerate(self._beam_hyps):
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
for i in range(batch_size):
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]

View File

@ -3522,10 +3522,10 @@ class GenerationMixin:
else self.generation_config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
device = input_ids.device
batch_beam_size, cur_len = input_ids.shape
@ -3648,6 +3648,7 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]