🚨🚨 Generate: change order of ops in beam sample to avoid nans (#26843)

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante 2023-10-17 10:32:49 +01:00 committed by GitHub
parent 0b8604d002
commit 4b423e6074
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 15 deletions

View File

@ -1430,14 +1430,22 @@ class TFGenerationMixin:
# instantiate warpers list
warpers = TFLogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(generation_config.eos_token_id) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config.eos_token_id, list):
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
return warpers
def _get_logits_processor(
@ -2366,14 +2374,11 @@ class TFGenerationMixin:
log_probs = tf.nn.log_softmax(logits)
log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, num_beams)
log_probs_processed = log_probs
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
if do_sample:
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits
# warpers (like top_p) this is indiferent, but on others (like temperature) it is not. For reference,
# see https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
log_probs = logits_warper(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, num_beams)
log_probs_processed = log_probs
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
vocab_size = log_probs.shape[2]
log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))

View File

@ -820,11 +820,20 @@ class GenerationMixin:
# instantiate warpers list
warpers = LogitsProcessorList()
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config.eos_token_id, list):
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
@ -3406,18 +3415,15 @@ class GenerationMixin:
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores_processed
)
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
# (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (logits_warper(input_ids, next_token_scores_processed),)
scores += (next_token_scores_processed,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)