mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[Beam Search] Correct returned beam scores (#14654)
* better * save intermediate * finish code * up * docs * Apply suggestions from code review * up * add compute transition beam scores function to model and make sure scores are correct with eos * apply nicos comments * Apply suggestions from code review * another fix
This commit is contained in:
parent
e239fc3b0b
commit
8d6acc6c29
@ -208,10 +208,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Final beam scores of the generated `sequences`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||
. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||
@ -223,6 +226,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
@ -241,10 +245,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Final beam scores of the generated `sequences`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||
. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||
config.vocab_size)`).
|
||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||
@ -267,6 +274,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
@ -286,10 +294,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Final beam scores of the generated `sequences`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||
. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||
@ -301,6 +312,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
@ -319,10 +331,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Final beam scores of the generated `sequences`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||
. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||
config.vocab_size)`).
|
||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
@ -343,6 +358,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
@ -743,6 +759,45 @@ class GenerationMixin:
|
||||
default_list.extend(custom_list)
|
||||
return default_list
|
||||
|
||||
def compute_transition_beam_scores(
|
||||
self,
|
||||
sequences: torch.Tensor,
|
||||
scores: Tuple[torch.Tensor],
|
||||
beam_indices: torch.Tensor,
|
||||
eos_token_id: int = None,
|
||||
):
|
||||
"""compute the transition probabilities of sequences given generation
|
||||
scores and beam indices"""
|
||||
|
||||
# reshape scores as [vocab_size * batch_size, # generation steps]
|
||||
# with batch_size being 2 * vocab_size and # generation steps being
|
||||
# seq_len - input_length
|
||||
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
|
||||
|
||||
# start of generated tokens
|
||||
cut_idx = sequences.shape[-1] - scores.shape[-1]
|
||||
# adjust for beam indices
|
||||
beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
|
||||
# compute real indices
|
||||
indices = sequences[:, cut_idx:] + beam_sequence_indices
|
||||
# gather scores and run
|
||||
transition_scores = scores.gather(0, indices)
|
||||
# make sure that if EOS token was used before length of sequence `sequence.shape[-1]`
|
||||
# get first occurence of EOS token
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
if eos_token_id is not None:
|
||||
is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
|
||||
# make sure first eos token still contributes to transition probs
|
||||
is_eos_token_id[:, -1] = False
|
||||
is_eos_token_id = is_eos_token_id.roll(1, -1)
|
||||
# all indices after eos shoud be masked
|
||||
zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
|
||||
# zero out padded probs
|
||||
transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)
|
||||
|
||||
return transition_scores
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
@ -1871,8 +1926,21 @@ class GenerationMixin:
|
||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||
)
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
beam_indices = (
|
||||
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
@ -1884,16 +1952,6 @@ class GenerationMixin:
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
@ -1932,13 +1990,13 @@ class GenerationMixin:
|
||||
next_token_logits, dim=-1
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
||||
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
scores += (next_token_scores_processed,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
@ -1973,6 +2031,7 @@ class GenerationMixin:
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
|
||||
beam_scores = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
@ -1985,6 +2044,9 @@ class GenerationMixin:
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
||||
|
||||
if return_dict_in_generate and output_scores:
|
||||
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
@ -2007,11 +2069,20 @@ class GenerationMixin:
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
else:
|
||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||
# return only as many indices as sequences
|
||||
beam_indices = tuple(
|
||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||
)
|
||||
beam_indices = sum(beam_indices, ())
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSearchEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
@ -2023,6 +2094,7 @@ class GenerationMixin:
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
@ -2175,8 +2247,16 @@ class GenerationMixin:
|
||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||
)
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
beam_indices = (
|
||||
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
@ -2188,11 +2268,6 @@ class GenerationMixin:
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
@ -2231,14 +2306,14 @@ class GenerationMixin:
|
||||
next_token_logits, dim=-1
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
||||
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
scores += (logits_warper(input_ids, next_token_scores_processed),)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
@ -2289,6 +2364,9 @@ class GenerationMixin:
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
||||
|
||||
if return_dict_in_generate and output_scores:
|
||||
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
@ -2311,11 +2389,20 @@ class GenerationMixin:
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
else:
|
||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||
# return only as many indices as sequences
|
||||
beam_indices = tuple(
|
||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||
)
|
||||
beam_indices = sum(beam_indices, ())
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSampleEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
@ -2327,6 +2414,7 @@ class GenerationMixin:
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
@ -2472,6 +2560,24 @@ class GenerationMixin:
|
||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||
)
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
num_beam_groups = beam_scorer.num_beam_groups
|
||||
num_sub_beams = num_beams // num_beam_groups
|
||||
device = input_ids.device
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
if return_dict_in_generate and output_scores:
|
||||
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
|
||||
else:
|
||||
beam_indices = None
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
@ -2485,19 +2591,6 @@ class GenerationMixin:
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
num_beam_groups = beam_scorer.num_beam_groups
|
||||
num_sub_beams = num_beams // num_beam_groups
|
||||
device = input_ids.device
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||
# the same group don't produce same tokens everytime.
|
||||
@ -2564,15 +2657,14 @@ class GenerationMixin:
|
||||
) # (batch_size * group_size, vocab_size)
|
||||
vocab_size = next_token_scores.shape[-1]
|
||||
|
||||
next_token_scores = logits_processor(
|
||||
next_token_scores_processed = logits_processor(
|
||||
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
||||
)
|
||||
next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
|
||||
next_token_scores
|
||||
)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
||||
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
||||
|
||||
if output_scores:
|
||||
processed_score[batch_group_indices] = next_token_scores
|
||||
processed_score[batch_group_indices] = next_token_scores_processed
|
||||
|
||||
# reshape for beam search
|
||||
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||
@ -2597,6 +2689,11 @@ class GenerationMixin:
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
if return_dict_in_generate and output_scores:
|
||||
beam_indices[beam_group_idx] = tuple(
|
||||
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
|
||||
)
|
||||
|
||||
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
||||
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
||||
@ -2655,11 +2752,21 @@ class GenerationMixin:
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
else:
|
||||
beam_indices = sum(beam_indices, ())
|
||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||
# return only as many indices as sequences
|
||||
beam_indices = tuple(
|
||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||
)
|
||||
beam_indices = sum(beam_indices, ())
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSearchEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
|
@ -1903,3 +1903,147 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
output_sequences_with_mask = output_sequences_with_mask.cpu()
|
||||
|
||||
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder(self):
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_beam_scores(
|
||||
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||
)
|
||||
transition_scores_sum = transition_scores.sum(-1)
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_beam_scores(
|
||||
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||
)
|
||||
transition_scores_sum = transition_scores.sum(-1)
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_decoder_only(self):
|
||||
articles = [
|
||||
"Justin Timberlake",
|
||||
"Michael Phelps",
|
||||
]
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-gpt2",
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_beam_scores(
|
||||
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||
)
|
||||
transition_scores_sum = transition_scores.sum(-1)
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_sample_encoder_decoder(self):
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
do_sample=True,
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_beam_scores(
|
||||
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||
)
|
||||
transition_scores_sum = transition_scores.sum(-1)
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_group_beam_search_encoder_decoder(self):
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
max_length=10,
|
||||
num_beams=2,
|
||||
num_beam_groups=2,
|
||||
num_return_sequences=2,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_beam_scores(
|
||||
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||
)
|
||||
transition_scores_sum = transition_scores.sum(-1)
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
Loading…
Reference in New Issue
Block a user