mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix gradient checkpointing bug in git (#21818)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
50db741417
commit
e07a3d95f8
@ -431,6 +431,13 @@ class GitEncoder(nn.Module):
|
||||
pixel_values_present: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
@ -443,11 +450,6 @@ class GitEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
Loading…
Reference in New Issue
Block a user