diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a6231e79f39..50bd825249e 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -979,9 +979,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache - if isinstance(past_key_values, EncoderDecoderCache) is not None - else None, + past_key_values.self_attention_cache if isinstance(past_key_values, EncoderDecoderCache) else None, output_attentions, ) else: