[tf/flax] handle forced_decoder_ids deletion (#38316)

fix tf/flax, attr checks
This commit is contained in:
Joao Gante 2025-05-23 10:44:58 +01:00 committed by GitHub
parent 9eb0a37c9e
commit 3e960e032d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

View File

@ -531,13 +531,16 @@ class FlaxGenerationMixin:
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
if (
getattr(generation_config, "forced_decoder_ids", None) is not None
and len(generation_config.forced_decoder_ids) > 0
):
# generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append(
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
if getattr(generation_config, "forced_decoder_ids", None) is not None:
forced_decoder_ids = [
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
]

View File

@ -1490,14 +1490,14 @@ class TFGenerationMixin:
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None:
if getattr(generation_config, "forced_decoder_ids", None) is not None:
begin_index += generation_config.forced_decoder_ids[-1][
0
] # generation starts after the last token that is forced
processors.append(
TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
if getattr(generation_config, "forced_decoder_ids", None) is not None:
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)