diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 27311ddde07..6cc0ef7c994 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -732,6 +732,13 @@ class BloomModel(BloomPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         # Compute alibi tensor: check build_alibi_tensor documentation
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -756,11 +763,6 @@ class BloomModel(BloomPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning_once(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
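
For context, the diff hoists the `use_cache` check out of the per-layer loop to the top of `forward`, so the incompatibility is resolved once before any layer runs and downstream bookkeeping (e.g. past-key-value lengths) sees the corrected flag. Below is a minimal sketch of that pattern, not the actual `BloomModel`: `TinyModel` and its fields are hypothetical, and it uses stdlib `logging.warning` in place of transformers' `warning_once` helper.

```python
import logging

logger = logging.getLogger(__name__)


class TinyModel:
    """Hypothetical stand-in illustrating the hoisted use_cache check."""

    def __init__(self, num_layers: int = 3):
        self.gradient_checkpointing = True
        self.training = True
        self.num_layers = num_layers

    def forward(self, hidden_states, use_cache: bool = True):
        # Hoisted check: runs once per forward pass, before the layer loop,
        # mirroring the first hunk of the diff above.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. "
                    "Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        for _ in range(self.num_layers):
            # Per-layer work goes here; the use_cache decision is already
            # settled, so the loop body stays free of the check (second hunk).
            pass
        return hidden_states, presents


# Usage: with checkpointing active during training, the warning fires once
# and the returned presents is None.
model = TinyModel()
_, presents = model.forward(hidden_states=None, use_cache=True)
assert presents is None
```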