mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00
[Generate] Fix gradient_checkpointing and use_cache bug for BLOOM (#21956)
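Step 1 - Change use_cache fix: move the `use_cache=True` / gradient-checkpointing incompatibility check earlier in `BloomModel` (from around line 756 up to line 732)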
This commit is contained in:
parent
934d0b8bdd
commit
f3c75f8b44
@@ -732,6 +732,13 @@ class BloomModel(BloomPreTrainedModel):
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
@@ -756,11 +763,6 @@ class BloomModel(BloomPreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
Loading…
Reference in New Issue
Block a user