mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)

[tf/flax] handle forced_decoder_ids deletion (#38316)

fix tf/flax attr checks

This commit is contained in:
parent 9eb0a37c9e
commit 3e960e032d
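Background: the commit title refers to handling the deletion of forced_decoder_ids, i.e. a generation config on which the attribute no longer exists at all. A direct generation_config.forced_decoder_ids access then raises AttributeError, whereas getattr(generation_config, "forced_decoder_ids", None) preserves the old "is not None" semantics and simply skips the guarded branch. Below is a minimal sketch of that difference; the SimpleNamespace stand-in and the token id 50259 are placeholders for illustration, not the real transformers GenerationConfig.

# Minimal sketch (illustration only): why getattr with a default is safer than
# direct attribute access when the attribute may have been deleted.
from types import SimpleNamespace

config = SimpleNamespace(forced_decoder_ids=[[1, 50259]])

# While the attribute exists, both styles of check agree.
assert config.forced_decoder_ids is not None
assert getattr(config, "forced_decoder_ids", None) is not None

# After the attribute is deleted, only the getattr form stays safe.
del config.forced_decoder_ids

try:
    _ = config.forced_decoder_ids is not None  # old check
except AttributeError:
    print("direct access raises AttributeError")

if getattr(config, "forced_decoder_ids", None) is not None:  # new check
    print("forced decoding requested")
else:
    print("no forced_decoder_ids on this config")  # this branch runs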
@@ -531,13 +531,16 @@ class FlaxGenerationMixin:
             if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
             else begin_index + 1
         )
-        if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
+        if (
+            getattr(generation_config, "forced_decoder_ids", None) is not None
+            and len(generation_config.forced_decoder_ids) > 0
+        ):
             # generation starts after the last token that is forced
             begin_index += generation_config.forced_decoder_ids[-1][0]
         processors.append(
             FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
         )
-        if generation_config.forced_decoder_ids is not None:
+        if getattr(generation_config, "forced_decoder_ids", None) is not None:
             forced_decoder_ids = [
                 [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
             ]
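To trace the indexing in the Flax hunk with made-up values (the positions, token ids, and starting begin_index below are hypothetical): forced_decoder_ids is a list of (position, token_id) pairs, so forced_decoder_ids[-1][0] is the position of the last forced token and the begin-suppression processor only acts after it, while the second branch shifts each forced position by the prompt length.

# Hypothetical values, purely to trace the arithmetic used in the hunk above.
forced_decoder_ids = [[1, 50259], [2, 50359]]  # (position, token_id) pairs
input_ids_seq_length = 1
begin_index = 1  # assumed starting value for this example

begin_index += forced_decoder_ids[-1][0]  # 1 + 2 -> 3: suppression starts after the last forced token
remapped = [[input_ids_seq_length + pos - 1, tok] for pos, tok in forced_decoder_ids]

print(begin_index)  # 3
print(remapped)     # [[1, 50259], [2, 50359]] -- unchanged here because input_ids_seq_length == 1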
@@ -1490,14 +1490,14 @@ class TFGenerationMixin:
             if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
             else begin_index + 1
         )
-        if generation_config.forced_decoder_ids is not None:
+        if getattr(generation_config, "forced_decoder_ids", None) is not None:
             begin_index += generation_config.forced_decoder_ids[-1][
                 0
             ]  # generation starts after the last token that is forced
         processors.append(
             TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
         )
-        if generation_config.forced_decoder_ids is not None:
+        if getattr(generation_config, "forced_decoder_ids", None) is not None:
             processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))

         processors = self._merge_criteria_processor_list(processors, logits_processor)
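One pre-existing difference the commit leaves alone: the Flax guard above also requires a non-empty list (the len(...) > 0 clause), while both TF guards only check for None; the change is limited to how the attribute is read. A tiny sketch of that edge case, using a hypothetical config whose list is empty:

# Hypothetical edge case: the attribute exists but is an empty list.
from types import SimpleNamespace

cfg = SimpleNamespace(forced_decoder_ids=[])

# Flax-style guard: False (empty list fails the length check).
print(getattr(cfg, "forced_decoder_ids", None) is not None and len(cfg.forced_decoder_ids) > 0)

# TF-style guard: True (only checks that the attribute resolves to something non-None).
print(getattr(cfg, "forced_decoder_ids", None) is not None)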