Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
🚨🚨 Generate: change order of ops in beam sample to avoid nans (#26843)
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit: 4b423e6074 (parent 0b8604d002)
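Background on the "avoid nans" in the title (my reading of the failure mode, not a claim lifted from this commit): multinomial sampling breaks down whenever an entire row of candidate scores has been masked to -inf, because the softmax of such a row is NaN. The toy snippet below, written in PyTorch only for brevity and with an arbitrary toy vocab size, reproduces that degenerate case in isolation.

import torch

# A row where every candidate has been filtered to -inf (e.g. by overly aggressive
# warping): softmax has no probability mass to distribute and returns NaN.
vocab_size = 5  # arbitrary toy value
all_filtered = torch.full((1, vocab_size), float("-inf"))
probs = torch.softmax(all_filtered, dim=-1)
print(probs)  # tensor([[nan, nan, nan, nan, nan]])

# Sampling from such a row is impossible; torch.multinomial rejects it at runtime.
try:
    torch.multinomial(probs, num_samples=1)
except RuntimeError as err:
    print(err)

The hunks below make two related changes: derive min_tokens_to_keep from the number of EOS tokens so that beam methods always keep at least one non-EOS candidate alive, and apply the logits warpers before, rather than after, the running beam scores are added.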
@@ -1430,14 +1430,22 @@ class TFGenerationMixin:
         # instantiate warpers list
         warpers = TFLogitsProcessorList()
 
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(generation_config.eos_token_id) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config.eos_token_id, list):
+                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
         if generation_config.temperature is not None and generation_config.temperature != 1.0:
             warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
         if generation_config.top_k is not None and generation_config.top_k != 0:
-            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.top_p is not None and generation_config.top_p < 1.0:
-            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
+            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
         return warpers
 
     def _get_logits_processor(
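A quick check of what the new min_tokens_to_keep buys (my own sketch, not part of the commit: it uses the PyTorch TopKLogitsWarper because it mirrors the TF one, and the EOS ids, vocab size and scores are made up). If the warper may keep as few tokens as there are EOS ids, a beam whose best-scoring candidates are all EOS has nothing else left to explore; keeping len(eos_token_id) + 1 tokens guarantees at least one non-EOS survivor.

import torch
from transformers import TopKLogitsWarper

eos_token_id = [50256, 50257]                       # hypothetical pair of EOS ids
vocab_size = 50258                                  # toy vocab just large enough for the ids above
scores = torch.full((1, vocab_size), -10.0)
scores[0, eos_token_id] = torch.tensor([1.0, 0.5])  # both EOS ids outrank everything else
scores[0, 42] = 0.0                                 # best non-EOS candidate
input_ids = torch.zeros((1, 1), dtype=torch.long)   # unused by TopKLogitsWarper

# Old behaviour: only the top_k tokens survive, which here are the two EOS ids.
old = TopKLogitsWarper(top_k=2, min_tokens_to_keep=1)(input_ids, scores.clone())
print(torch.isfinite(old).sum().item())  # 2 -> only the EOS candidates remain

# New behaviour: keep len(eos_token_id) + 1 tokens, so a non-EOS continuation survives.
new = TopKLogitsWarper(top_k=2, min_tokens_to_keep=len(eos_token_id) + 1)(input_ids, scores.clone())
print(torch.isfinite(new).sum().item())  # 3 -> token 42 is still available

The PyTorch hunk further below applies the same rule to GenerationMixin._get_logits_warper.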
@@ -2366,14 +2374,11 @@ class TFGenerationMixin:
             log_probs = tf.nn.log_softmax(logits)
             log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
             log_probs = unflatten_beam_dim(log_probs, num_beams)
-            log_probs_processed = log_probs
-            log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             if do_sample:
-                # Note: logits warpers are intentionally applied after adding running beam scores. On some logits
-                # warpers (like top_p) this is indiferent, but on others (like temperature) it is not. For reference,
-                # see https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
                 log_probs = logits_warper(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
                 log_probs = unflatten_beam_dim(log_probs, num_beams)
+            log_probs_processed = log_probs
+            log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             vocab_size = log_probs.shape[2]
             log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))
 
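For readers unfamiliar with the helpers in this hunk: flatten_beam_dim and unflatten_beam_dim merge and split the batch and beam axes, so that processors and warpers, which expect (batch * beams, vocab)-shaped scores, can run on the (batch, beams, vocab) tensors used inside beam search. A rough stand-in for the round trip (my own sketch, in PyTorch rather than TF, with made-up shapes):

import torch

def flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor:
    """Merge the batch and beam axes: (batch, beams, ...) -> (batch * beams, ...)."""
    return tensor.reshape(-1, *tensor.shape[2:])

def unflatten_beam_dim(tensor: torch.Tensor, num_beams: int) -> torch.Tensor:
    """Split them back: (batch * beams, ...) -> (batch, beams, ...)."""
    return tensor.reshape(-1, num_beams, *tensor.shape[1:])

batch_size, num_beams, vocab_size = 2, 3, 11    # toy shapes
log_probs = torch.randn(batch_size, num_beams, vocab_size)

flat = flatten_beam_dim(log_probs)              # shape (6, 11): what processors/warpers see
restored = unflatten_beam_dim(flat, num_beams)  # shape (2, 3, 11)
assert torch.equal(restored, log_probs)

Seen through that lens, the reorder above is simply: warp the per-token log-probabilities while they are still per-step distributions, and only afterwards add running_scores and reshape into the (batch_size, num_beams * vocab_size) view used for sampling.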
@@ -820,11 +820,20 @@ class GenerationMixin:
         # instantiate warpers list
         warpers = LogitsProcessorList()
 
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config.eos_token_id, list):
+                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
         if generation_config.temperature is not None and generation_config.temperature != 1.0:
             warpers.append(TemperatureLogitsWarper(generation_config.temperature))
-        min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
         if generation_config.top_k is not None and generation_config.top_k != 0:
             warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.top_p is not None and generation_config.top_p < 1.0:
@@ -3406,18 +3415,15 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)
 
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
-            # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
-            # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
-            # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
-            next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
-                    scores += (logits_warper(input_ids, next_token_scores_processed),)
+                    scores += (next_token_scores_processed,)
                 if output_attentions:
                     decoder_attentions += (
                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
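The comment removed above claimed the old order was harmless for top_p but not for temperature, and that is easy to sanity-check (my own sketch with made-up numbers, not code from the repository): adding a per-beam constant such as a running beam score commutes with top-k/top-p style filtering, but not with temperature scaling.

import torch

scores = torch.tensor([[2.0, 1.0, 0.0, -1.0]])  # toy per-token log-probs for one beam
beam_score = torch.tensor([[-5.0]])             # toy running score of that beam

def top_k_mask(x: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Boolean mask of the k best tokens per row."""
    return x >= torch.topk(x, k).values[..., -1, None]

# Top-k keeps the same tokens whether the (per-row constant) beam score is added before or after.
assert torch.equal(top_k_mask(scores), top_k_mask(scores + beam_score))

# Temperature does not commute with adding the beam score: the old order also rescales the
# accumulated beam score, the new order only rescales the current step's log-probs.
temperature = 0.5
old_order = (scores + beam_score) / temperature
new_order = scores / temperature + beam_score
print(torch.allclose(old_order, new_order))  # False; they differ by beam_score * (1/temperature - 1)

That behaviour change for temperature (and any other warper that is not shift-invariant) when do_sample=True and num_beams > 1 is presumably what the 🚨🚨 breaking-change markers in the title refer to.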