mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[tf/flax] handle forced_decoder_ids
deletion (#38316)
fix tf/flax, attr checks
This commit is contained in:
parent
9eb0a37c9e
commit
3e960e032d
@ -531,13 +531,16 @@ class FlaxGenerationMixin:
|
||||
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
|
||||
else begin_index + 1
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
|
||||
if (
|
||||
getattr(generation_config, "forced_decoder_ids", None) is not None
|
||||
and len(generation_config.forced_decoder_ids) > 0
|
||||
):
|
||||
# generation starts after the last token that is forced
|
||||
begin_index += generation_config.forced_decoder_ids[-1][0]
|
||||
processors.append(
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
if getattr(generation_config, "forced_decoder_ids", None) is not None:
|
||||
forced_decoder_ids = [
|
||||
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
|
||||
]
|
||||
|
@ -1490,14 +1490,14 @@ class TFGenerationMixin:
|
||||
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
|
||||
else begin_index + 1
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
if getattr(generation_config, "forced_decoder_ids", None) is not None:
|
||||
begin_index += generation_config.forced_decoder_ids[-1][
|
||||
0
|
||||
] # generation starts after the last token that is forced
|
||||
processors.append(
|
||||
TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
if getattr(generation_config, "forced_decoder_ids", None) is not None:
|
||||
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
|
||||
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
|
Loading…
Reference in New Issue
Block a user