From f3c75f8b44eba4f3bcc43b65f10488203cd78a6c Mon Sep 17 00:00:00 2001
From: Srimanth Agastyaraju <30816357+asrimanth@users.noreply.github.com>
Date: Mon, 6 Mar 2023 09:56:40 -0500
Subject: [PATCH] [Generate] Fix gradient_checkpointing and use_cache bug for
 BLOOM (#21956)

Step 1 - Change use_cache fix
---
 src/transformers/models/bloom/modeling_bloom.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 27311ddde07..6cc0ef7c994 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -732,6 +732,13 @@ class BloomModel(BloomPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         # Compute alibi tensor: check build_alibi_tensor documentation
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -756,11 +763,6 @@ class BloomModel(BloomPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning_once(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
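
For reference, a minimal, self-contained sketch of the pattern this patch
applies. The `ToyDecoder` module below is hypothetical (it is not from the
Transformers codebase) and uses plain `logging.warning` where the patch uses
transformers' `logger.warning_once`; it only illustrates why the guard is
hoisted: resolving the `use_cache`/gradient-checkpointing conflict once,
before the layer loop, means every later consumer of the flag (cache
initialization, per-layer branching, return values) sees the corrected value
rather than a stale `use_cache=True`.

import logging

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

logger = logging.getLogger(__name__)


class ToyDecoder(nn.Module):
    """Hypothetical decoder stack illustrating the guard placement from the patch."""

    def __init__(self, hidden_size: int = 64, num_layers: int = 2):
        super().__init__()
        self.h = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor, use_cache: bool = True):
        # As in the patch: resolve the conflict once, before the loop and
        # before any cache bookkeeping depends on the value of `use_cache`.
        # (The real code uses transformers' logger.warning_once here.)
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. "
                    "Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None  # stand-in for the KV cache
        for block in self.h:
            if self.gradient_checkpointing and self.training:
                # Recompute this block's activations during backward
                # instead of storing them.
                hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
            else:
                hidden_states = block(hidden_states)
            if use_cache:
                presents = presents + (hidden_states,)
        return hidden_states, presents

Moving the check ahead of the loop (first hunk) and deleting the in-loop copy
(second hunk) does not change which warning is emitted; it only guarantees the
flag is already consistent by the time the loop starts.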