diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index ddd718cbb8a..3c8c4795a84 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -531,13 +531,16 @@ class FlaxGenerationMixin: if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) else begin_index + 1 ) - if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0: + if ( + getattr(generation_config, "forced_decoder_ids", None) is not None + and len(generation_config.forced_decoder_ids) > 0 + ): # generation starts after the last token that is forced begin_index += generation_config.forced_decoder_ids[-1][0] processors.append( FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) ) - if generation_config.forced_decoder_ids is not None: + if getattr(generation_config, "forced_decoder_ids", None) is not None: forced_decoder_ids = [ [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids ] diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 510186cafc0..ae77f32e269 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -1490,14 +1490,14 @@ class TFGenerationMixin: if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) else begin_index + 1 ) - if generation_config.forced_decoder_ids is not None: + if getattr(generation_config, "forced_decoder_ids", None) is not None: begin_index += generation_config.forced_decoder_ids[-1][ 0 ] # generation starts after the last token that is forced processors.append( TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) ) - if generation_config.forced_decoder_ids is not None: + if getattr(generation_config, "forced_decoder_ids", None) is not None: processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) processors = self._merge_criteria_processor_list(processors, logits_processor)