Apply GradientCheckpointingLayer to the whole repo (#38913)

* first batch (4)

* align

* altclip

* beit

* bert

* yolos

* dino, pvt_v2

* bark, bart, bert_generation

* big_bird, biogpt

* blenderbot, bloom

* bridgetower

* camembert, canine, chameleon

* chinese clip, clap, clip

* codegen, conditional detr, convbert

* dab_detr, data2vec

* dbrx, deberta

* deberta, decision_transformer, deformable_detr

* deit, deta, mctct

* detr, dinov2, distilbert

* donut, dpt, electra

* ernie, esm, falcon

* flava, fnet, falcon_mamba

* focalnet, git, gpt2

* gpt_bigcode, gpt_neo, gpt_neox

* gptj, groupvit

* idefics2, idefics3

* ijepa, imagegpt, internvl

* jetmoe, kosmos2, layoutlm

* layoutlm2-3, led

* lilt, longformer, longt5, luke

* m2m, mamba1-2

* marian, markuplm, mask2former

* maskformer

* mbart, megatron_bert, mimi

* mixtral, mlcd

* mobilevit1-2, modernbert

* moshi, mpt, mra

* mt5, musicgen

* mvp, nemotron

* nllb_moe

* nystromformer, omdet_turbo

* opt, owlvit, owlv2

* pegasus, pegasus_x, persimmon

* phimoe, pix2struct, pixtral

* plbart, pop2piano, prophetnet

* qwen2*

* qwen2, qwen3 moe, recurrent_gemma

* rembert

* roberta

* roberta prelayernorm

* roc_bert, roformer, rwkv

* sam, sam_hq

* seggpt, smolvlm, speech_to_text

* splinter, stablelm, swin

* swin2sr, switch_transformer, t5, table_transformer

* tapas, time_series_transformer, timesformer

* trocr, tvp, umt5

* videomae, vilt, visual_bert

* vit, vit_mae, vit_msn

* vitpose_backbone, vits, vivit

* whisper, x_clip, xglm

* xlm_roberta, xmod

* yoso

* zamba

* vitdet, wav2vec2, wav2vec2_bert

* unispeech, wav2vec_conformer

* wavlm

* speecht5

* swinv2

* sew / _d

* seamless_m4t / _v2

* deprecated models update

* bros

* gemma2, gemma3

* got, hiera, hubert, llama4, mllama, oneformer, phi, olmoe, informer

* fixup

* Add use_cache=False and past_key_value=None to GradientCheckpointingLayer

* fixup

* fix prophetnet

* fix bigbird_pegasus

* fix blenderbot

* fix mbart

* fix mvp

* fix zamba2

* fix bart

* fix blenderbot_small

* fix codegen

* Update gradient checkpointing layer to support more past_key_values arg names

* fix data2vec vision

* fix deformable_detr

* fix gptj

* fix led

* fix m2m_100

* add comment

* fix nllb_moe

* Fix pegasus_x

* fix plbart

* udop

* fix-copies: beit, wav2vec2

* fix gpt_bigcode

* fixup

* fix t5

* fix switch_transformers

* fix longt5

* fix mt5

* update tapas

* fix blip2

* update blip

* fix musicgen

* fix gpt2, trocr

* fix copies

* !!! Revert zamba, mllama

* update autoformer

* update bros

* update args / kwargs for BERT and copies

* 2nd round of updates

* update conditional detr

* Pass encoder_hidden_states as positional arg

* Update to pass encoder_decoder_position_bias as positional arg

* fixup

* biogpt modular

* modular gemma2

* modular gemma3

* modular gpt_neox

* modular informer

* modular internvl

* modular mixtral

* modular mlcd

* modular modernbert

* modular phi

* modular qwen2_5_omni

* modular qwen2_5_vl

* modular sam_hq

* modular sew

* wav2vec2_bert

* modular wav2vec2_conformer

* modular wavlm

* fixup

* Update by modular instructblipvideo

* modular data2vec_audio

* nit modular mistral

* apply modular minimax

* fix modular moonshine

* revert zamba2

* fix mask2former

* refactor idefics
Pavel Iakubovskii 2025-06-23 13:24:48 +01:00 committed by GitHub
parent 07aab1af1e
commit 84d19be41e
224 changed files with 2513 additions and 5280 deletions
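Reviewer note: the user-facing API is unchanged by this refactor; layers that now subclass GradientCheckpointingLayer are checkpointed transparently once checkpointing is enabled on the model. A minimal usage sketch, not part of the diff ("gpt2" is just one checkpoint covered by this PR):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# Layers that subclass GradientCheckpointingLayer wrap their own __call__,
# so encoder/decoder loops no longer need an explicit checkpointing branch.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()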

View File

@@ -16,6 +16,11 @@ from functools import partial
 
 import torch.nn as nn
 
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 class GradientCheckpointingLayer(nn.Module):
     """Base class for layers with gradient checkpointing.
@@ -44,5 +49,35 @@ class GradientCheckpointingLayer(nn.Module):
 
     def __call__(self, *args, **kwargs):
         if self.gradient_checkpointing and self.training:
+            do_warn = False
+            layer_name = self.__class__.__name__
+            message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"
+
+            if "use_cache" in kwargs and kwargs["use_cache"]:
+                kwargs["use_cache"] = False
+                message += " `use_cache=False`,"
+                do_warn = True
+
+            # different names for the same thing in different layers
+            if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
+                kwargs["past_key_value"] = None
+                message += " `past_key_value=None`,"
+                do_warn = True
+
+            if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
+                kwargs["past_key_values"] = None
+                message += " `past_key_values=None`,"
+                do_warn = True
+
+            if "layer_past" in kwargs and kwargs["layer_past"] is not None:
+                kwargs["layer_past"] = None
+                message += " `layer_past=None`,"
+                do_warn = True
+
+            # warn if anything was changed
+            if do_warn:
+                message = message.rstrip(",") + "."
+                logger.warning(message)
+
             return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
         return super().__call__(*args, **kwargs)
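A hedged sketch of how a subclass picks up the behavior above. ToyLayer is hypothetical, and the gradient_checkpointing / _gradient_checkpointing_func attributes (names taken from the code shown here) are normally set by PreTrainedModel.gradient_checkpointing_enable() rather than by hand:

from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from transformers.modeling_layers import GradientCheckpointingLayer


class ToyLayer(GradientCheckpointingLayer):  # hypothetical layer, not from the diff
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states, use_cache=False, past_key_value=None):
        # Under checkpointing the base class has already forced use_cache=False
        # and past_key_value=None and logged a single warning.
        return self.proj(hidden_states)


layer = ToyLayer().train()
layer.gradient_checkpointing = True  # normally toggled via gradient_checkpointing_enable()
layer._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

x = torch.randn(2, 8, requires_grad=True)
out = layer(x, use_cache=True)  # warns once, then runs checkpointed with caching disabled
out.sum().backward()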

View File

@@ -23,6 +23,7 @@ import torch.utils.checkpoint
 from torch import nn
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithNoAttention,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -827,7 +828,7 @@ class AlignTextOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
-class AlignTextLayer(nn.Module):
+class AlignTextLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -953,27 +954,15 @@ class AlignTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:
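Reviewer note on the "# as a positional argument for gradient checkpointing" comments: in GradientCheckpointingLayer.__call__ (modeling_layers hunk above), keyword arguments are frozen into a functools.partial while positional arguments become the inputs of the checkpoint function, so tensors such as encoder_hidden_states stay positional in every updated encoder/decoder loop. A minimal sketch of that split, with a hypothetical layer_forward standing in for a real layer:

from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def layer_forward(hidden_states, attention_mask, encoder_hidden_states, output_attentions=False):
    # Stand-in for a layer's forward; only the call pattern matters here.
    return hidden_states + encoder_hidden_states.mean()


hidden_states = torch.randn(2, 8, requires_grad=True)
attention_mask = torch.ones(2, 8)
encoder_hidden_states = torch.randn(2, 8, requires_grad=True)

# Keyword args are bound ahead of time; positional tensors are the explicit
# checkpoint() inputs, mirroring partial(super().__call__, **kwargs) plus *args.
fn = partial(layer_forward, output_attentions=False)
out = checkpoint(fn, hidden_states, attention_mask, encoder_hidden_states, use_reentrant=False)
out.sum().backward()  # gradients reach both hidden_states and encoder_hidden_states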

View File

@@ -23,6 +23,7 @@ import torch.nn as nn
 import torch.utils.checkpoint
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -418,7 +419,7 @@ class AltRobertaOutput(nn.Module):
 # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
-class AltRobertaLayer(nn.Module):
+class AltRobertaLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -544,27 +545,15 @@ class AltRobertaEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -732,7 +721,7 @@ class AltCLIPMLP(nn.Module):
         return hidden_states
-class AltCLIPEncoderLayer(nn.Module):
+class AltCLIPEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: AltCLIPConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -848,21 +837,12 @@ class AltCLIPEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]

View File

@@ -22,6 +22,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -282,7 +283,7 @@ class ASTOutput(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
-class ASTLayer(nn.Module):
+class ASTLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""
     def __init__(self, config: ASTConfig) -> None:
@@ -349,16 +350,7 @@ class ASTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
             hidden_states = layer_outputs[0]
             if output_attentions:

View File

@@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
     _prepare_4d_attention_mask_for_sdpa,
 )
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput
 from ...modeling_utils import PreTrainedModel
 from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
@@ -670,7 +671,7 @@ class AutoformerAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
-class AutoformerEncoderLayer(nn.Module):
+class AutoformerEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: AutoformerConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -744,7 +745,7 @@ class AutoformerEncoderLayer(nn.Module):
         return outputs
-class AutoformerDecoderLayer(nn.Module):
+class AutoformerDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: AutoformerConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1042,21 +1043,12 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
                 hidden_states = layer_outputs[0]
@@ -1186,6 +1178,12 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
         input_shape = inputs_embeds.size()[:-1]
         # expand encoder attention mask
@@ -1228,38 +1226,17 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
             (hidden_states, residual_trend) = layer_outputs[0]
             trend = trend + residual_trend

View File

@@ -31,6 +31,7 @@ from ...generation.logits_process import (
 )
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
@@ -309,7 +310,7 @@ class BarkMLP(nn.Module):
         return hidden_states
-class BarkBlock(nn.Module):
+class BarkBlock(GradientCheckpointingLayer):
     def __init__(self, config, is_causal=False):
         super().__init__()
@@ -606,25 +607,14 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    past_key_values=past_layer_key_values,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            outputs = block(
+                hidden_states,
+                past_key_values=past_layer_key_values,
+                attention_mask=attention_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
             hidden_states = outputs[0]

View File

@@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask_for_sdpa,
 )
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -270,7 +271,7 @@ class BartAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
-class BartEncoderLayer(nn.Module):
+class BartEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -341,7 +342,7 @@ class BartEncoderLayer(nn.Module):
         return outputs
-class BartDecoderLayer(nn.Module):
+class BartDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -875,21 +876,12 @@ class BartEncoder(BartPreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
                 hidden_states = layer_outputs[0]
@@ -1137,35 +1129,18 @@ class BartDecoder(BartPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -26,6 +26,7 @@ from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BackboneOutput,
     BaseModelOutput,
@@ -497,7 +498,7 @@ class BeitOutput(nn.Module):
         return hidden_states
-class BeitLayer(nn.Module):
+class BeitLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""
     def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
@@ -525,7 +526,7 @@ class BeitLayer(nn.Module):
         output_attentions: bool = False,
         relative_position_bias: Optional[torch.Tensor] = None,
         interpolate_pos_encoding: bool = False,
-        resolution: Optional[tuple[int]] = None,
+        resolution: Optional[tuple[int, int]] = None,
     ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
         self_attention_outputs = self.attention(
             self.layernorm_before(hidden_states),  # in BEiT, layernorm is applied before self-attention
@@ -695,25 +696,14 @@ class BeitEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                    relative_position_bias,
-                    interpolate_pos_encoding,
-                    resolution,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                    relative_position_bias,
-                    interpolate_pos_encoding,
-                    resolution,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+                relative_position_bias=relative_position_bias,
+                interpolate_pos_encoding=interpolate_pos_encoding,
+                resolution=resolution,
+            )
             hidden_states = layer_outputs[0]

View File

@@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -522,7 +523,7 @@ class BertOutput(nn.Module):
         return hidden_states
-class BertLayer(nn.Module):
+class BertLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -647,27 +648,15 @@ class BertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -23,6 +23,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
@@ -275,7 +276,7 @@ class BertGenerationOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration
-class BertGenerationLayer(nn.Module):
+class BertGenerationLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -401,27 +402,15 @@ class BertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -1419,7 +1420,7 @@ class BigBirdOutput(nn.Module):
         return hidden_states
-class BigBirdLayer(nn.Module):
+class BigBirdLayer(GradientCheckpointingLayer):
     def __init__(self, config, seed=None):
         super().__init__()
         self.config = config
@@ -1593,35 +1594,19 @@ class BigBirdEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    band_mask,
-                    from_mask,
-                    to_mask,
-                    blocked_encoder_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    band_mask,
-                    from_mask,
-                    to_mask,
-                    blocked_encoder_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                band_mask,
+                from_mask,
+                to_mask,
+                blocked_encoder_mask,
+                past_key_value,
+                output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask_for_sdpa,
 )
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -1333,7 +1334,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
-class BigBirdPegasusEncoderLayer(nn.Module):
+class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BigBirdPegasusConfig, seed=None):
         super().__init__()
         self.attention_type = config.attention_type
@@ -1420,7 +1421,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
         self.self_attn.set_attention_type(value)
-class BigBirdPegasusDecoderLayer(nn.Module):
+class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1947,31 +1948,17 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        band_mask,
-                        from_mask,
-                        to_mask,
-                        blocked_encoder_mask,
-                        blocked_encoder_mask,
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        band_mask=band_mask,
-                        from_mask=from_mask,
-                        to_mask=to_mask,
-                        from_blocked_mask=blocked_encoder_mask,
-                        to_blocked_mask=blocked_encoder_mask,
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    band_mask=band_mask,
+                    from_mask=from_mask,
+                    to_mask=to_mask,
+                    from_blocked_mask=blocked_encoder_mask,
+                    to_blocked_mask=blocked_encoder_mask,
+                    output_attentions=output_attentions,
+                )
                 hidden_states = layer_outputs[0]
@@ -2297,35 +2284,18 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -20,7 +20,6 @@
 # limitations under the License.
 import math
-from functools import partial
 from typing import Callable, Optional, Union
 import torch
@@ -32,6 +31,7 @@ from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -248,7 +248,7 @@ class BioGptAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
-class BioGptDecoderLayer(nn.Module):
+class BioGptDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -646,30 +646,17 @@ class BioGptModel(BioGptPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    causal_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    position_ids,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    position_ids=position_ids,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )
             hidden_states = layer_outputs[0]

View File

@@ -15,7 +15,6 @@
 """PyTorch BioGPT model."""
 import math
-from functools import partial
 from typing import Optional, Union
 import torch
@@ -473,30 +472,17 @@ class BioGptModel(BioGptPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    causal_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    position_ids,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    position_ids=position_ids,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )
             hidden_states = layer_outputs[0]

View File

@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask_for_sdpa,
 )
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -270,7 +271,7 @@ class BlenderbotAttention(nn.Module):
 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
-class BlenderbotEncoderLayer(nn.Module):
+class BlenderbotEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BlenderbotConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -339,7 +340,7 @@ class BlenderbotEncoderLayer(nn.Module):
 # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
-class BlenderbotDecoderLayer(nn.Module):
+class BlenderbotDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -825,21 +826,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
                 hidden_states = layer_outputs[0]
@@ -1090,35 +1082,18 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    causal_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                causal_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask_for_sdpa,
 )
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -254,7 +255,7 @@ class BlenderbotSmallAttention(nn.Module):
 # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
-class BlenderbotSmallEncoderLayer(nn.Module):
+class BlenderbotSmallEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -326,7 +327,7 @@ class BlenderbotSmallEncoderLayer(nn.Module):
 # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
-class BlenderbotSmallDecoderLayer(nn.Module):
+class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -812,21 +813,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
                 hidden_states = layer_outputs[0]
@@ -1073,35 +1065,18 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    causal_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                causal_mask,
+                encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -552,7 +552,7 @@ class BlipEncoder(nn.Module):
             layer_outputs = encoder_layer(
                 hidden_states,
-                attention_mask,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
             )

View File

@@ -531,7 +531,7 @@ class Blip2Encoder(nn.Module):
             layer_outputs = encoder_layer(
                 hidden_states,
-                attention_mask,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
             )
@@ -992,11 +992,11 @@ class Blip2QFormerEncoder(nn.Module):
                 hidden_states,
                 attention_mask,
                 layer_head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                past_key_value,
-                output_attentions,
-                query_length,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                query_length=query_length,
             )
             hidden_states = layer_outputs[0]

View File

@@ -27,6 +27,7 @@ from torch.nn import functional as F
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -366,7 +367,7 @@ class BloomMLP(nn.Module):
         return output
-class BloomBlock(nn.Module):
+class BloomBlock(GradientCheckpointingLayer):
     def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
         super().__init__()
         hidden_size = config.hidden_size
@@ -605,29 +606,16 @@ class BloomModel(BloomPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    alibi,
-                    causal_mask,
-                    past_key_values,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=past_key_values,
-                    attention_mask=causal_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                    cache_position=cache_position,
-                )
+            outputs = block(
+                hidden_states,
+                layer_past=past_key_values,
+                attention_mask=causal_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                alibi=alibi,
+                cache_position=cache_position,
+            )
             hidden_states = outputs[0]
             if use_cache:

View File

@@ -25,6 +25,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN, QuickGELUActivation
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -662,7 +663,7 @@ class BridgeTowerBertCrossLayer(nn.Module):
         return layer_output
-class BridgeTowerTextLayer(nn.Module):
+class BridgeTowerTextLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -788,27 +789,15 @@ class BridgeTowerTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:

View File

@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -428,7 +429,7 @@ class BrosOutput(nn.Module):
         return hidden_states
-class BrosLayer(nn.Module):
+class BrosLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -550,34 +551,16 @@ class BrosEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    bbox_pos_emb,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states=hidden_states,
-                    bbox_pos_emb=bbox_pos_emb,
-                    attention_mask=attention_mask,
-                    head_mask=layer_head_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                bbox_pos_emb,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:
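The Bros hunk also drops the old in-loop `use_cache` warning. With every layer sharing one checkpointing wrapper, the incompatibility between KV caching and checkpointed re-computation can be neutralised in a single place rather than policed model by model; the helper below is only an illustration of what that normalisation amounts to (names invented here, not the upstream code):

# Illustrative only: checkpointed re-computation and KV caching do not mix,
# so a shared wrapper can neutralise cache-related kwargs once, centrally.
def scrub_cache_kwargs(kwargs: dict) -> dict:
    for name in ("past_key_value", "past_key_values", "layer_past"):
        if name in kwargs:
            kwargs[name] = None
    if "use_cache" in kwargs:
        kwargs["use_cache"] = False
    return kwargs


print(scrub_cache_kwargs({"use_cache": True, "layer_past": object()}))
# {'use_cache': False, 'layer_past': None}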


@@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN, gelu
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -478,7 +479,7 @@ class CamembertOutput(nn.Module):
 # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert
-class CamembertLayer(nn.Module):
+class CamembertLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -604,27 +605,15 @@ class CamembertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:


@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     ModelOutput,
@@ -672,7 +673,7 @@ class CanineOutput(nn.Module):
         return hidden_states
-class CanineLayer(nn.Module):
+class CanineLayer(GradientCheckpointingLayer):
     def __init__(
         self,
         config,
@@ -779,16 +780,7 @@ class CanineEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
             hidden_states = layer_outputs[0]
             if output_attentions:


@@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
@@ -383,7 +384,7 @@ class ChameleonAttention(nn.Module):
 # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
-class ChameleonDecoderLayer(nn.Module):
+class ChameleonDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: ChameleonConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -458,7 +459,7 @@ class ChameleonDecoderLayer(nn.Module):
         return outputs
-class ChameleonSwinDecoderLayer(nn.Module):
+class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: ChameleonConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1011,28 +1012,16 @@ class ChameleonModel(ChameleonPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
             hidden_states = layer_outputs[0]
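For decoder-only models such as Chameleon the call site now always goes through the layer, forwarding the cache, `cache_position` and attention kwargs unconditionally (note that `**kwargs` was previously dropped on the checkpointed path). From the user side nothing changes; checkpointing is still switched on through the model API. The checkpoint name below is only an example:

# Enabling checkpointing is unchanged for users; the layers handle the rest.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative checkpoint
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()
# forward/backward as usual; each decoder layer checkpoints its own forward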


@@ -23,6 +23,7 @@ import torch.utils.checkpoint
 from torch import nn
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -577,7 +578,7 @@ class ChineseCLIPVisionMLP(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
-class ChineseCLIPTextLayer(nn.Module):
+class ChineseCLIPTextLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -663,7 +664,7 @@ class ChineseCLIPTextLayer(nn.Module):
         return layer_output
-class ChineseCLIPVisionLayer(nn.Module):
+class ChineseCLIPVisionLayer(GradientCheckpointingLayer):
     def __init__(self, config: ChineseCLIPConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -816,27 +817,15 @@ class ChineseCLIPTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -920,17 +909,10 @@ class ChineseCLIPVisionEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]


@@ -24,6 +24,7 @@ import torch.nn.functional as F
 from torch import nn
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPooling,
@@ -691,7 +692,7 @@ class ClapAudioLayer(nn.Module):
 # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio
-class ClapAudioStage(nn.Module):
+class ClapAudioStage(GradientCheckpointingLayer):
     def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
         super().__init__()
         self.config = config
@@ -928,14 +929,9 @@ class ClapAudioEncoder(nn.Module):
             input_dimensions = self.input_resolutions[i]
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
-                )
+            layer_outputs = layer_module(
+                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+            )
             hidden_states = layer_outputs[0]
@@ -1355,7 +1351,7 @@ class ClapTextOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText
-class ClapTextLayer(nn.Module):
+class ClapTextLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -1481,27 +1477,15 @@ class ClapTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:


@@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
@@ -393,7 +394,7 @@ class CLIPMLP(nn.Module):
         return hidden_states
-class CLIPEncoderLayer(nn.Module):
+class CLIPEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -575,21 +576,12 @@ class CLIPEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]


@@ -25,6 +25,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging, torch_int
@@ -374,7 +375,7 @@ class CLIPSegMLP(nn.Module):
 # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
-class CLIPSegEncoderLayer(nn.Module):
+class CLIPSegEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: CLIPSegConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -539,22 +540,12 @@ class CLIPSegEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if output_attentions:


@@ -24,6 +24,7 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -245,7 +246,7 @@ class CodeGenMLP(nn.Module):
 # Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
-class CodeGenBlock(nn.Module):
+class CodeGenBlock(GradientCheckpointingLayer):
     # Ignore copy
     def __init__(self, config, layer_idx=None):
         super().__init__()
@@ -437,29 +438,16 @@ class CodeGenModel(CodeGenPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    causal_mask,
-                    position_ids,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                )
-            else:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=past_key_values,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
+            outputs = block(
+                hidden_states,
+                layer_past=past_key_values,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
             hidden_states = outputs[0]
             if use_cache is True:
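CodeGen (like GPT-2) names its cache argument `layer_past` rather than `past_key_value`, so a shared checkpointing wrapper has to recognise several cache-argument names to neutralise caching correctly. The snippet below is a self-contained, inspect-based illustration of how such detection could work; it is not the actual upstream mechanism:

# Sketch: discover which cache-argument name a block uses, so one wrapper can
# serve BERT-style layers (past_key_value) and GPT-2-style blocks (layer_past).
import inspect
from typing import Optional


def cache_arg_name(forward_fn) -> Optional[str]:
    candidates = ("past_key_value", "past_key_values", "layer_past")
    params = inspect.signature(forward_fn).parameters
    return next((name for name in candidates if name in params), None)


def codegen_like_forward(hidden_states, layer_past=None, attention_mask=None):
    return hidden_states


print(cache_arg_name(codegen_like_forward))  # layer_past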


@@ -23,6 +23,7 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends
@@ -827,7 +828,7 @@ class ConditionalDetrEncoderLayer(nn.Module):
         return outputs
-class ConditionalDetrDecoderLayer(nn.Module):
+class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: ConditionalDetrConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1297,31 +1298,18 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
                 pos_transformation = self.query_scale(hidden_states)
             # apply transformation
             query_sine_embed = query_sine_embed_before_transformation * pos_transformation
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    None,
-                    object_queries,
-                    query_position_embeddings,
-                    query_sine_embed,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    None,
-                    None,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=None,
-                    object_queries=object_queries,
-                    query_position_embeddings=query_position_embeddings,
-                    query_sine_embed=query_sine_embed,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    output_attentions=output_attentions,
-                    is_first=(idx == 0),
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                None,  # attention_mask
+                object_queries,
+                query_position_embeddings,
+                query_sine_embed,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+                is_first=(idx == 0),
+            )
             hidden_states = layer_outputs[0]
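The recurring `# as a positional argument for gradient checkpointing` comment is the other half of the refactor: tensors that must receive gradients through the checkpoint, such as `encoder_hidden_states` here, are passed positionally, because the reentrant variant of `torch.utils.checkpoint` does not support keyword arguments at all and only tracks positional tensor inputs. Keeping these tensors positional keeps both checkpointing modes usable. A self-contained check of the positional case:

# Positional tensors flow through reentrant checkpointing and receive gradients.
import torch
from torch.utils.checkpoint import checkpoint


def layer(hidden_states, encoder_hidden_states=None):
    return hidden_states * 2 + encoder_hidden_states.sum()


x = torch.randn(2, 4, requires_grad=True)
enc = torch.randn(2, 4, requires_grad=True)

out = checkpoint(layer, x, enc, use_reentrant=True)  # enc passed positionally
out.sum().backward()
print(enc.grad is not None)  # True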


@@ -25,6 +25,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN, get_activation
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
     MaskedLMOutput,
@@ -532,7 +533,7 @@ class ConvBertOutput(nn.Module):
         return hidden_states
-class ConvBertLayer(nn.Module):
+class ConvBertLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -620,25 +621,14 @@ class ConvBertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)


@@ -23,6 +23,7 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -702,7 +703,7 @@ class DabDetrDecoderLayerFFN(nn.Module):
 # Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig
-class DabDetrEncoderLayer(nn.Module):
+class DabDetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DabDetrConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -764,7 +765,7 @@ class DabDetrEncoderLayer(nn.Module):
 # Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr
-class DabDetrDecoderLayer(nn.Module):
+class DabDetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DabDetrConfig, is_first: bool = False):
         super().__init__()
         self.self_attn = DabDetrDecoderLayerSelfAttention(config)
@@ -976,21 +977,12 @@ class DabDetrEncoder(DabDetrPreTrainedModel):
             # we add object_queries * pos_scaler as extra input to the encoder_layer
             scaled_object_queries = object_queries * pos_scales
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    scaled_object_queries,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    object_queries=scaled_object_queries,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                object_queries=scaled_object_queries,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
@@ -1138,29 +1130,16 @@ class DabDetrDecoder(DabDetrPreTrainedModel):
                     reference_anchor_size[..., 1] / obj_center[..., 3]
                 ).unsqueeze(-1)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    None,
-                    object_queries,
-                    query_pos,
-                    query_sine_embed,
-                    encoder_hidden_states,
-                    memory_key_padding_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=None,
-                    object_queries=object_queries,
-                    query_position_embeddings=query_pos,
-                    query_sine_embed=query_sine_embed,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=memory_key_padding_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                None,  # attention_mask
+                object_queries,
+                query_pos,
+                query_sine_embed,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=memory_key_padding_mask,
+                output_attentions=output_attentions,
+            )
             # iter update
             hidden_states = layer_outputs[0]


@@ -33,6 +33,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...integrations.fsdp import is_fsdp_managed_module
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
@@ -51,7 +52,7 @@ if is_torch_flex_attn_available():
     from ...integrations.flex_attention import make_flex_block_causal_mask
-class Data2VecAudioConvLayer(nn.Module):
+class Data2VecAudioConvLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@@ -155,13 +156,7 @@ class Data2VecAudioFeatureEncoder(nn.Module):
             hidden_states.requires_grad = True
         for conv_layer in self.conv_layers:
-            if self._requires_grad and self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    conv_layer.__call__,
-                    hidden_states,
-                )
-            else:
-                hidden_states = conv_layer(hidden_states)
+            hidden_states = conv_layer(hidden_states)
         return hidden_states
@@ -357,7 +352,7 @@ class Data2VecAudioFeedForward(nn.Module):
         return hidden_states
-class Data2VecAudioEncoderLayer(nn.Module):
+class Data2VecAudioEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.attention = Data2VecAudioAttention(
@@ -441,17 +436,9 @@ class Data2VecAudioEncoder(nn.Module):
             skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
             if not skip_the_layer or synced_gpus:
                 # under fsdp or deepspeed zero3 all gpus must run in sync
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = layer(
-                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
-                    )
+                layer_outputs = layer(
+                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+                )
                 hidden_states = layer_outputs[0]
             if skip_the_layer:
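Data2Vec audio shows the pattern reaching beyond transformer blocks: the convolutional feature-extractor layers also become `GradientCheckpointingLayer`s, so the `_requires_grad and self.gradient_checkpointing` branch in the feature encoder collapses to a plain call. Checkpointing a convolution is behaviour-preserving, which a quick stand-alone check confirms:

# Stand-alone sanity check: a checkpointed conv gives the same output and
# still propagates gradients to its input.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

conv = nn.Conv1d(1, 4, kernel_size=3)
x = torch.randn(2, 1, 16, requires_grad=True)

plain = conv(x)
ckpt = checkpoint(conv, x, use_reentrant=False)
print(torch.allclose(plain, ckpt))  # True

ckpt.sum().backward()
print(x.grad.shape)  # torch.Size([2, 1, 16])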


@@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN, gelu
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -375,7 +376,7 @@ class Data2VecTextOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
-class Data2VecTextLayer(nn.Module):
+class Data2VecTextLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -501,27 +502,15 @@ class Data2VecTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:


@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -497,7 +498,7 @@ class Data2VecVisionOutput(nn.Module):
 # Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
-class Data2VecVisionLayer(nn.Module):
+class Data2VecVisionLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""
     def __init__(
@@ -527,7 +528,7 @@ class Data2VecVisionLayer(nn.Module):
         output_attentions: bool = False,
         relative_position_bias: Optional[torch.Tensor] = None,
         interpolate_pos_encoding: bool = False,
-        resolution: Optional[tuple[int]] = None,
+        resolution: Optional[tuple[int, int]] = None,
     ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
         self_attention_outputs = self.attention(
             self.layernorm_before(hidden_states),  # in Data2VecVision, layernorm is applied before self-attention
@@ -699,25 +700,14 @@ class Data2VecVisionEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                    relative_position_bias,
-                    interpolate_pos_encoding,
-                    resolution,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                    relative_position_bias,
-                    interpolate_pos_encoding,
-                    resolution,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+                relative_position_bias=relative_position_bias,
+                interpolate_pos_encoding=interpolate_pos_encoding,
+                resolution=resolution,
+            )
             hidden_states = layer_outputs[0]


@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import Wav2Vec2BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ..wav2vec2.modeling_wav2vec2 import (
@@ -38,7 +39,7 @@ from ..wav2vec2.modeling_wav2vec2 import (
 from .configuration_data2vec_audio import Data2VecAudioConfig
-class Data2VecAudioConvLayer(nn.Module):
+class Data2VecAudioConvLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1


@@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_torch_flex_attn_available, logging
@@ -724,7 +725,7 @@ class DbrxFFN(nn.Module):
         return out, weights
-class DbrxBlock(nn.Module):
+class DbrxBlock(GradientCheckpointingLayer):
     def __init__(self, config: DbrxConfig, block_idx: int):
         super().__init__()
         self.hidden_size = config.d_model
@@ -947,29 +948,16 @@ class DbrxModel(DbrxPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                block_outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    output_router_logits,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                block_outputs = block(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    output_router_logits=output_router_logits,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            block_outputs = block(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                output_router_logits=output_router_logits,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
             hidden_states = block_outputs[0]


@@ -22,6 +22,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
@@ -492,7 +493,7 @@ class DebertaOutput(nn.Module):
         return hidden_states
-class DebertaLayer(nn.Module):
+class DebertaLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.attention = DebertaAttention(config)
@@ -580,25 +581,14 @@ class DebertaEncoder(nn.Module):
         rel_embeddings = self.get_rel_embedding()
         for i, layer_module in enumerate(self.layer):
-            if self.gradient_checkpointing and self.training:
-                hidden_states, att_m = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    next_kv,
-                    attention_mask,
-                    query_states,
-                    relative_pos,
-                    rel_embeddings,
-                    output_attentions,
-                )
-            else:
-                hidden_states, att_m = layer_module(
-                    next_kv,
-                    attention_mask,
-                    query_states=query_states,
-                    relative_pos=relative_pos,
-                    rel_embeddings=rel_embeddings,
-                    output_attentions=output_attentions,
-                )
+            hidden_states, att_m = layer_module(
+                next_kv,
+                attention_mask,
+                query_states=query_states,
+                relative_pos=relative_pos,
+                rel_embeddings=rel_embeddings,
+                output_attentions=output_attentions,
+            )
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)


@@ -23,6 +23,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
@@ -418,7 +419,7 @@ class DebertaV2Output(nn.Module):
 # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
-class DebertaV2Layer(nn.Module):
+class DebertaV2Layer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.attention = DebertaV2Attention(config)
@@ -655,25 +656,14 @@ class DebertaV2Encoder(nn.Module):
         next_kv = hidden_states
         rel_embeddings = self.get_rel_embedding()
         for i, layer_module in enumerate(self.layer):
-            if self.gradient_checkpointing and self.training:
-                output_states, attn_weights = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    next_kv,
-                    attention_mask,
-                    query_states,
-                    relative_pos,
-                    rel_embeddings,
-                    output_attentions,
-                )
-            else:
-                output_states, attn_weights = layer_module(
-                    next_kv,
-                    attention_mask,
-                    query_states=query_states,
-                    relative_pos=relative_pos,
-                    rel_embeddings=rel_embeddings,
-                    output_attentions=output_attentions,
-                )
+            output_states, attn_weights = layer_module(
+                next_kv,
+                attention_mask,
+                query_states=query_states,
+                relative_pos=relative_pos,
+                rel_embeddings=rel_embeddings,
+                output_attentions=output_attentions,
+            )
             if output_attentions:
                 all_attentions = all_attentions + (attn_weights,)


@@ -25,6 +25,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
@@ -360,7 +361,7 @@ class DecisionTransformerGPT2MLP(nn.Module):
 # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
-class DecisionTransformerGPT2Block(nn.Module):
+class DecisionTransformerGPT2Block(GradientCheckpointingLayer):
     # Ignore copy
     def __init__(self, config, layer_idx=None):
         super().__init__()
@@ -654,31 +655,17 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    past_key_value=past_key_values,
-                    cache_position=cache_position,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            outputs = block(
+                hidden_states,
+                past_key_values if not (self.gradient_checkpointing and self.training) else None,
+                cache_position,
+                attention_mask,
+                head_mask[i],
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
            hidden_states = outputs[0]
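In the GPT-2-style block used by Decision Transformer the cache is passed positionally, so the caching/checkpointing incompatibility stays explicit at the call site as an inline guard instead of being hidden in the wrapper. The guard in isolation, with toy values for illustration only:

# Toy illustration of the inline guard: the cache is only forwarded when the
# block will not be checkpointed during training.
gradient_checkpointing, training = True, True
past_key_values = ["pretend this is a Cache object"]

layer_past = past_key_values if not (gradient_checkpointing and training) else None
print(layer_past)  # None -> no caching while checkpointing is active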


@@ -27,6 +27,7 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
@@ -759,7 +760,7 @@ class DeformableDetrMultiheadAttention(nn.Module):
         return attn_output, attn_weights_reshaped
-class DeformableDetrEncoderLayer(nn.Module):
+class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -848,7 +849,7 @@ class DeformableDetrEncoderLayer(nn.Module):
         return outputs
-class DeformableDetrDecoderLayer(nn.Module):
+class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1126,29 +1127,16 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
        for i, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    position_embeddings,
-                    reference_points,
-                    spatial_shapes,
-                    spatial_shapes_list,
-                    level_start_index,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    position_embeddings=position_embeddings,
-                    reference_points=reference_points,
-                    spatial_shapes=spatial_shapes,
-                    spatial_shapes_list=spatial_shapes_list,
-                    level_start_index=level_start_index,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                position_embeddings=position_embeddings,
+                reference_points=reference_points,
+                spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
+                level_start_index=level_start_index,
+                output_attentions=output_attentions,
+            )
            hidden_states = layer_outputs[0]
@@ -1273,31 +1261,17 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                position_embeddings,
-                reference_points_input,
-                spatial_shapes,
-                spatial_shapes_list,
-                level_start_index,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                output_attentions,
-            ) if self.gradient_checkpointing and self.training else decoder_layer(
-                hidden_states,
-                position_embeddings=position_embeddings,
-                encoder_hidden_states=encoder_hidden_states,
-                reference_points=reference_points_input,
-                spatial_shapes=spatial_shapes,
-                spatial_shapes_list=spatial_shapes_list,
-                level_start_index=level_start_index,
-                encoder_attention_mask=encoder_attention_mask,
-                output_attentions=output_attentions,
-            )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings,
+                reference_points_input,
+                spatial_shapes,
+                spatial_shapes_list,
+                level_start_index,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask,
+                output_attentions,
+            )
            hidden_states = layer_outputs[0]


@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -347,7 +348,7 @@ class DeiTOutput(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
-class DeiTLayer(nn.Module):
+class DeiTLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""
     def __init__(self, config: DeiTConfig) -> None:
@@ -414,15 +415,7 @@ class DeiTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
             hidden_states = layer_outputs[0]


@@ -39,6 +39,7 @@ from ....file_utils import (
     replace_return_docstrings,
 )
 from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutput
 from ....modeling_utils import PreTrainedModel
 from ....pytorch_utils import meshgrid
@@ -909,7 +910,7 @@ class DetaEncoderLayer(nn.Module):
         return outputs
-class DetaDecoderLayer(nn.Module):
+class DetaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetaConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1341,29 +1342,16 @@ class DetaDecoder(DetaPreTrainedModel):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    position_embeddings,
-                    reference_points_input,
-                    spatial_shapes,
-                    level_start_index,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings=position_embeddings,
-                    encoder_hidden_states=encoder_hidden_states,
-                    reference_points=reference_points_input,
-                    spatial_shapes=spatial_shapes,
-                    level_start_index=level_start_index,
-                    encoder_attention_mask=encoder_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                encoder_hidden_states=encoder_hidden_states,
+                reference_points=reference_points_input,
+                spatial_shapes=spatial_shapes,
+                level_start_index=level_start_index,
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+            )
            hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add
 from ....integrations.deepspeed import is_deepspeed_zero3_enabled
 from ....integrations.fsdp import is_fsdp_managed_module
 from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutput, CausalLMOutput
 from ....modeling_utils import (
     PreTrainedModel,
@@ -377,7 +378,7 @@ class MCTCTOutput(nn.Module):
         return hidden_states
-class MCTCTLayer(nn.Module):
+class MCTCTLayer(GradientCheckpointingLayer):
     def __init__(self, config: MCTCTConfig):
         super().__init__()
@@ -591,20 +592,11 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states=hidden_states,
-                        attention_mask=attention_mask,
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states=hidden_states,
+                    attention_mask=attention_mask,
+                    output_attentions=output_attentions,
+                )
                hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ....activations import ACT2FN
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -438,7 +439,7 @@ class NezhaOutput(nn.Module):
         return hidden_states
-class NezhaLayer(nn.Module):
+class NezhaLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -563,27 +564,15 @@ class NezhaEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:


@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ....activations import ACT2FN
 from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ....modeling_utils import PreTrainedModel
 from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -339,7 +340,7 @@ class OpenLlamaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-class OpenLlamaDecoderLayer(nn.Module):
+class OpenLlamaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: OpenLlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -631,25 +632,14 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                    None,
-                    output_attentions,
-                    None,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )

             hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ....activations import ACT2FN
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -452,7 +453,7 @@ class QDQBertOutput(nn.Module):

 # Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert
-class QDQBertLayer(nn.Module):
+class QDQBertLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.seq_len_dim = 1
@@ -568,32 +569,15 @@ class QDQBertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning_once(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:
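The removed warning above ("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") used to be duplicated in nearly every encoder. With a shared checkpointing layer it only needs to live in one place; a hedged sketch of that kind of centralised handling (argument names are assumptions, not the actual API):

# Hedged sketch: caching and activation recomputation do not mix, so a shared
# wrapper can neutralise cache arguments once instead of every model repeating
# the warning removed above. Names here are illustrative.
import logging

logger = logging.getLogger(__name__)


def sanitize_checkpointed_kwargs(kwargs: dict) -> dict:
    cleaned = dict(kwargs)
    if cleaned.get("use_cache"):
        logger.warning(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        cleaned["use_cache"] = False
    if "past_key_value" in cleaned:
        cleaned["past_key_value"] = None  # nothing can be reused when the pass is recomputed
    return cleaned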


@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss

 from ....activations import ACT2FN
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -447,7 +448,7 @@ class RealmOutput(nn.Module):
         return hidden_states


-class RealmLayer(nn.Module):
+class RealmLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -572,27 +573,15 @@ class RealmEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
 from ....activations import ACT2FN
 from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from ....modeling_utils import PreTrainedModel
 from ....utils import add_start_docstrings, logging, replace_return_docstrings
@@ -263,7 +264,7 @@ class Speech2Text2Attention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value


-class Speech2Text2DecoderLayer(nn.Module):
+class Speech2Text2DecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Speech2Text2Config):
         super().__init__()
         self.embed_dim = config.d_model
@@ -612,31 +613,17 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -25,6 +25,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F

+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_utils import PreTrainedModel
 from ....utils import (
     ModelOutput,
@@ -346,7 +347,7 @@ class CausalSelfAttention(nn.Module):
         return outputs


-class Block(nn.Module):
+class Block(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.ln1 = nn.LayerNorm(config.n_embd)
@@ -540,16 +541,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    layer_past,
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                outputs = block(hidden_states, layer_past, use_cache, output_attentions)
+            outputs = block(hidden_states, layer_past, use_cache, output_attentions)

             hidden_states = outputs[0]
             if use_cache is True:


@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ....activations import ACT2FN
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput
 from ....modeling_utils import PreTrainedModel
 from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -483,7 +484,7 @@ class TvltOutput(nn.Module):
         return hidden_states


-class TvltLayer(nn.Module):
+class TvltLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""

     def __init__(self, config):
@@ -546,16 +547,7 @@ class TvltEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]
@@ -853,15 +845,7 @@ class TvltDecoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    None,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
+            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

             hidden_states = layer_outputs[0]


@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ....activations import ACT2FN
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ....modeling_utils import PreTrainedModel
 from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -390,7 +391,7 @@ VIT_HYBRID_ATTENTION_CLASSES = {
 }


-class ViTHybridLayer(nn.Module):
+class ViTHybridLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""

     def __init__(self, config: ViTHybridConfig) -> None:
@@ -457,15 +458,7 @@ class ViTHybridEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from torch import Tensor, nn
 from torch.nn import LayerNorm

 from ....activations import ACT2FN
+from ....modeling_layers import GradientCheckpointingLayer
 from ....modeling_outputs import BaseModelOutput
 from ....modeling_utils import PreTrainedModel
 from ....utils import (
@@ -1090,7 +1091,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
         return predict_relative_pos_embeddings


-class XLMProphetNetEncoderLayer(nn.Module):
+class XLMProphetNetEncoderLayer(GradientCheckpointingLayer):
     """
     Encoder block for XLMProphetnet
     """
@@ -1133,7 +1134,7 @@ class XLMProphetNetEncoderLayer(nn.Module):
         return outputs


-class XLMProphetNetDecoderLayer(nn.Module):
+class XLMProphetNetDecoderLayer(GradientCheckpointingLayer):
     """
     Decoder block for XLMProphetnet
     """
@@ -1320,21 +1321,12 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
             if output_hidden_states:
                 encoder_hidden_states = encoder_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    extended_attention_mask,
-                    (head_mask[idx] if head_mask is not None else None),
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask=extended_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]
@@ -1554,41 +1546,21 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    extended_attention_mask,
-                    encoder_hidden_states,
-                    extended_encoder_attention_mask,
-                    (head_mask[idx] if head_mask is not None else None),
-                    (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
-                    extended_predict_attention_mask,
-                    main_relative_position_buckets,
-                    predict_relative_position_buckets,
-                    position_ids,
-                    None,
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=extended_attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attn_mask=extended_encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    extended_predict_attention_mask=extended_predict_attention_mask,
-                    main_relative_position_buckets=main_relative_position_buckets,
-                    predict_relative_position_buckets=predict_relative_position_buckets,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attn_mask=extended_encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                extended_predict_attention_mask=extended_predict_attention_mask,
+                main_relative_position_buckets=main_relative_position_buckets,
+                predict_relative_position_buckets=predict_relative_position_buckets,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]


@@ -23,6 +23,7 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -677,7 +678,7 @@ class DetrEncoderLayer(nn.Module):
         return outputs


-class DetrDecoderLayer(nn.Module):
+class DetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetrConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1045,25 +1046,15 @@ class DetrDecoder(DetrPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    combined_attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    None,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=combined_attention_mask,
-                    object_queries=object_queries,
-                    query_position_embeddings=query_position_embeddings,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                combined_attention_mask,
+                object_queries,
+                query_position_embeddings,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]
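The "# as a positional argument for gradient checkpointing" comment marks a constraint worth spelling out: the reentrant flavour of torch.utils.checkpoint.checkpoint does not accept extra keyword arguments for the wrapped callable, so tensors that must participate in the recomputed graph (here encoder_hidden_states) are passed positionally while plain flags can be bound up front. A small, hedged illustration outside of transformers:

# Hedged illustration: reentrant checkpointing rejects extra keyword arguments,
# so tensors are passed positionally and non-tensor options are bound via partial.
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def block(hidden_states, encoder_hidden_states, output_attentions=False):
    return hidden_states + encoder_hidden_states


hidden_states = torch.randn(2, 4, requires_grad=True)
encoder_hidden_states = torch.randn(2, 4, requires_grad=True)

out = checkpoint(
    partial(block, output_attentions=False),  # flags bound up front
    hidden_states,                            # tensors stay positional
    encoder_hidden_states,
    use_reentrant=True,
)
out.sum().backward()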


@@ -23,6 +23,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -382,7 +383,7 @@ class Dinov2SwiGLUFFN(nn.Module):
         return self.weights_out(hidden)


-class Dinov2Layer(nn.Module):
+class Dinov2Layer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the original implementation."""

     def __init__(self, config: Dinov2Config) -> None:
@@ -458,15 +459,7 @@ class Dinov2Encoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]


@@ -28,6 +28,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -399,7 +400,7 @@ class Dinov2WithRegistersSwiGLUFFN(nn.Module):
         return self.weights_out(hidden)


-class Dinov2WithRegistersLayer(nn.Module):
+class Dinov2WithRegistersLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the original implementation."""

     def __init__(self, config: Dinov2WithRegistersConfig) -> None:
@@ -476,15 +477,7 @@ class Dinov2WithRegistersEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]


@@ -31,6 +31,7 @@ from ...configuration_utils import PretrainedConfig
 from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
@@ -441,7 +442,7 @@ DISTILBERT_ATTENTION_CLASSES = {
 }


-class TransformerBlock(nn.Module):
+class TransformerBlock(GradientCheckpointingLayer):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
@@ -537,21 +538,12 @@ class Transformer(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_state,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_state,
-                    attn_mask,
-                    head_mask[i],
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_state,
-                    attn_mask,
-                    head_mask[i],
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_state,
+                attn_mask,
+                head_mask[i],
+                output_attentions,
+            )

             hidden_state = layer_outputs[-1]


@@ -27,6 +27,7 @@ import torch.utils.checkpoint
 from torch import nn

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import ModelOutput, auto_docstring, logging, torch_int
@@ -706,7 +707,7 @@ class DonutSwinLayer(nn.Module):

 # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
-class DonutSwinStage(nn.Module):
+class DonutSwinStage(GradientCheckpointingLayer):
     def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
         super().__init__()
         self.config = config
@@ -816,19 +817,9 @@ class DonutSwinEncoder(nn.Module):
         for i, layer_module in enumerate(self.layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    input_dimensions,
-                    layer_head_mask,
-                    output_attentions,
-                    always_partition,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
-                )
+            layer_outputs = layer_module(
+                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+            )

             hidden_states = layer_outputs[0]
             hidden_states_before_downsampling = layer_outputs[1]


@@ -29,6 +29,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -469,7 +470,7 @@ class DPTViTOutput(nn.Module):

 # copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput
-class DPTViTLayer(nn.Module):
+class DPTViTLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""

     def __init__(self, config: DPTConfig) -> None:
@@ -536,15 +537,7 @@ class DPTViTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN, get_activation
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -436,7 +437,7 @@ class ElectraOutput(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
-class ElectraLayer(nn.Module):
+class ElectraLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -562,27 +563,15 @@ class ElectraEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -361,7 +362,7 @@ class ErnieOutput(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie
-class ErnieLayer(nn.Module):
+class ErnieLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -487,27 +488,15 @@ class ErnieEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -599,7 +600,7 @@ class EsmOutput(nn.Module):
         return hidden_states


-class EsmLayer(nn.Module):
+class EsmLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -725,27 +726,15 @@ class EsmEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -556,7 +557,7 @@ FALCON_ATTENTION_CLASSES = {
 }


-class FalconDecoderLayer(nn.Module):
+class FalconDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: FalconConfig, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size
@@ -836,33 +837,18 @@ class FalconModel(FalconPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    alibi,
-                    causal_mask,
-                    position_ids,
-                    head_mask[i],
-                    past_key_values,
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=past_key_values,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                )
+            outputs = block(
+                hidden_states,
+                layer_past=past_key_values,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                alibi=alibi,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )

             hidden_states = outputs[0]
             if use_cache is True:


@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...cache_utils import MambaCache
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
@@ -405,7 +406,7 @@ class FalconMambaRMSNorm(nn.Module):

 # Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache
-class FalconMambaBlock(nn.Module):
+class FalconMambaBlock(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.config = config
@@ -620,17 +621,12 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
         hidden_states = inputs_embeds
         all_hidden_states = () if output_hidden_states else None
         for mixer_block in self.layers:
-            if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
-                )
-            else:
-                hidden_states = mixer_block(
-                    hidden_states,
-                    cache_params=cache_params,
-                    cache_position=cache_position,
-                    attention_mask=attention_mask,
-                )
+            hidden_states = mixer_block(
+                hidden_states,
+                cache_params=cache_params,
+                cache_position=cache_position,
+                attention_mask=attention_mask,
+            )

             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
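From the user's side nothing changes: checkpointing is still switched on at the model level and the layers decide internally whether to recompute. A hedged usage sketch with a small public checkpoint (the model choice is illustrative only):

# Hedged usage sketch: enabling gradient checkpointing stays a one-liner on the
# model; the per-layer if/else removed throughout this diff is no longer needed.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()  # activations are recomputed layer by layer during this pass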


@@ -25,6 +25,7 @@ import torch.utils.checkpoint
 from torch import nn

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import ModelOutput, auto_docstring, logging, torch_int
@@ -577,7 +578,7 @@ class FlavaOutput(nn.Module):
         return hidden_states


-class FlavaLayer(nn.Module):
+class FlavaLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""

     def __init__(self, config: FlavaPossibleConfigs) -> None:
@@ -648,16 +649,7 @@ class FlavaEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]


@@ -31,6 +31,7 @@ if is_scipy_available():
     from scipy import linalg

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -235,7 +236,7 @@ class FNetOutput(nn.Module):
         return hidden_states


-class FNetLayer(nn.Module):
+class FNetLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -276,10 +277,7 @@ class FNetEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
-            else:
-                layer_outputs = layer_module(hidden_states)
+            layer_outputs = layer_module(hidden_states)

             hidden_states = layer_outputs[0]


@@ -25,6 +25,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BackboneOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
@@ -455,7 +456,7 @@ class FocalNetLayer(nn.Module):
         return hidden_state


-class FocalNetStage(nn.Module):
+class FocalNetStage(GradientCheckpointingLayer):
     def __init__(self, config, index, input_resolution):
         super().__init__()
@@ -560,14 +561,7 @@ class FocalNetEncoder(nn.Module):
             all_reshaped_hidden_states += (reshaped_hidden_state,)

         for i, stage_module in enumerate(self.stages):
-            if self.gradient_checkpointing and self.training:
-                stage_outputs = self._gradient_checkpointing_func(
-                    stage_module.__call__,
-                    hidden_states,
-                    input_dimensions,
-                )
-            else:
-                stage_outputs = stage_module(hidden_states, input_dimensions)
+            stage_outputs = stage_module(hidden_states, input_dimensions)

             hidden_states = stage_outputs[0]
             hidden_states_before_downsampling = stage_outputs[1]


@@ -19,7 +19,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 from typing import Callable, Optional, Union

 import torch
@@ -30,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -238,7 +238,7 @@ class Gemma2Attention(nn.Module):
         return attn_output, attn_weights


-class Gemma2DecoderLayer(nn.Module):
+class Gemma2DecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -466,30 +466,17 @@ class Gemma2Model(Gemma2PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings,
-                    causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings=position_embeddings,
-                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )

             hidden_states = layer_outputs[0]
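For Gemma2 and Gemma3 the change also drops the `from functools import partial` import: the checkpointed branch used to wrap the layer call in `partial(decoder_layer.__call__, **flash_attn_kwargs)` so the extra FlashAttention kwargs could cross the checkpoint boundary, whereas now there is a single ordinary call and the keyword arguments travel with it. A hedged before/after sketch with stand-in names:

# Hedged sketch of the call-site simplification; SketchLayer and the kwargs are
# stand-ins, not the real Gemma2 layer or FlashAttentionKwargs handling.
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SketchLayer(nn.Module):
    def forward(self, hidden_states, position_ids=None, **flash_attn_kwargs):
        return (hidden_states * 2,)


decoder_layer = SketchLayer()
hidden_states = torch.randn(1, 3, 4, requires_grad=True)
flash_attn_kwargs = {"cu_seq_lens_q": None}

# Before: the caller had to pre-bind kwargs for the checkpointed path.
layer_outputs = checkpoint(
    partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, use_reentrant=False
)

# After: one ordinary call; a checkpointing-aware layer can do the wrapping itself.
layer_outputs = decoder_layer(hidden_states, position_ids=None, **flash_attn_kwargs)
hidden_states = layer_outputs[0]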


@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 from typing import Callable, Optional, Union

 import torch
@@ -25,6 +24,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
@@ -303,7 +303,7 @@ class Gemma2Attention(GemmaAttention):
         return attn_output, attn_weights


-class Gemma2DecoderLayer(nn.Module):
+class Gemma2DecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -449,30 +449,17 @@ class Gemma2Model(GemmaModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings,
-                    causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings=position_embeddings,
-                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )

             hidden_states = layer_outputs[0]


@@ -22,7 +22,6 @@
 import copy
 from collections.abc import Callable
 from dataclasses import dataclass
-from functools import partial
 from typing import Optional, Union

 import torch
@@ -34,6 +33,7 @@ from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -364,7 +364,7 @@ class Gemma3Attention(nn.Module):
         return attn_output, attn_weights


-class Gemma3DecoderLayer(nn.Module):
+class Gemma3DecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         super().__init__()
         self.config = config
@@ -581,32 +581,18 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings_global,
-                    position_embeddings_local,
-                    causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings_global=position_embeddings_global,
-                    position_embeddings_local=position_embeddings_local,
-                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings_global=position_embeddings_global,
+                position_embeddings_local=position_embeddings_local,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )

             hidden_states = layer_outputs[0]


@@ -16,7 +16,6 @@
 import copy
 from collections.abc import Callable
 from dataclasses import dataclass
-from functools import partial
 from typing import Any, Optional, Union

 import torch
@@ -27,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -443,7 +443,7 @@ class Gemma3Attention(Gemma2Attention):
         return attn_output, attn_weights


-class Gemma3DecoderLayer(nn.Module):
+class Gemma3DecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         super().__init__()
         self.config = config
@@ -632,32 +632,18 @@ class Gemma3TextModel(Gemma2Model):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings_global,
-                    position_embeddings_local,
-                    causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings_global=position_embeddings_global,
-                    position_embeddings_local=position_embeddings_local,
-                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings_global=position_embeddings_global,
+                position_embeddings_local=position_embeddings_local,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )

             hidden_states = layer_outputs[0]


@@ -27,6 +27,7 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPast,

@@ -343,7 +344,7 @@ class GitOutput(nn.Module):
         return hidden_states


-class GitLayer(nn.Module):
+class GitLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward

@@ -441,24 +442,14 @@ class GitEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    past_key_values,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    past_key_values,
-                    output_attentions,
-                    pixel_values_present,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                past_key_values,
+                output_attentions,
+                pixel_values_present,
+            )

             hidden_states = layer_outputs[0]

             if use_cache:

@@ -723,7 +714,7 @@ class GitVisionAttention(nn.Module):


 # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
-class GitVisionEncoderLayer(nn.Module):
+class GitVisionEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: GitVisionConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -840,21 +831,12 @@ class GitVisionEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]


@@ -31,6 +31,7 @@ import torch.nn.functional as F
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack

@@ -192,7 +193,7 @@ class GotOcr2VisionAttention(nn.Module):
         return outputs


-class GotOcr2VisionLayer(nn.Module):
+class GotOcr2VisionLayer(GradientCheckpointingLayer):
     def __init__(self, config, window_size):
         super().__init__()
         self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

@@ -463,13 +464,7 @@ class GotOcr2VisionEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
+            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

             hidden_states = layer_outputs[0]


@@ -30,6 +30,7 @@ from ...activations import ACT2FN, get_activation
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,

@@ -368,7 +369,7 @@ class GPT2MLP(nn.Module):
         return hidden_states


-class GPT2Block(nn.Module):
+class GPT2Block(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size

@@ -922,32 +923,18 @@ class GPT2Model(GPT2PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    past_key_values,
-                    cache_position,
-                    causal_mask,
-                    head_mask[i],
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    past_key_value=past_key_values,
-                    cache_position=cache_position,
-                    attention_mask=causal_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    **kwargs,
-                )
+            outputs = block(
+                hidden_states,
+                past_key_values if not (self.gradient_checkpointing and self.training) else None,
+                cache_position,
+                causal_mask,
+                head_mask[i],
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                **kwargs,
+            )

             hidden_states = outputs[0]
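Two details in the new GPT2 call are deliberate rather than cosmetic: past_key_values is swapped for None while checkpointing (a cache cannot be grown while activations are being recomputed), and encoder_hidden_states stays positional, as the inline comment says. A hypothetical, self-contained illustration of the positional-argument point (made-up names, not transformers code): keyword arguments get bound with functools.partial before torch.utils.checkpoint runs, so only positional tensors are visible to checkpoint() as inputs.

    from functools import partial

    import torch
    from torch.utils.checkpoint import checkpoint

    lin = torch.nn.Linear(4, 4)
    hidden = torch.randn(2, 4, requires_grad=True)
    encoder_hidden = torch.randn(2, 4, requires_grad=True)

    def block(hidden_states, encoder_hidden_states=None, use_cache=False):
        out = lin(hidden_states)
        return (out + encoder_hidden_states) if encoder_hidden_states is not None else out

    # use_cache is a non-tensor kwarg, so freezing it into the partial is harmless;
    # encoder_hidden is passed positionally and is therefore a genuine checkpoint input.
    out = checkpoint(partial(block, use_cache=False), hidden, encoder_hidden, use_reentrant=False)
    out.sum().backward()
    assert hidden.grad is not None and encoder_hidden.grad is not None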


@@ -25,6 +25,7 @@ from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,

@@ -558,7 +559,7 @@ GPTBIGCODE_ATTENTION_CLASSES = {
 }


-class GPTBigCodeBlock(nn.Module):
+class GPTBigCodeBlock(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size

@@ -759,6 +760,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:

@@ -891,29 +898,16 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            outputs = block(
+                hidden_states,
+                layer_past,
+                attention_mask,
+                head_mask[i],
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )

             hidden_states = outputs[0]

             if use_cache:
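The guard added to GPTBigCodeModel.forward is what users hit in practice when they enable checkpointing on a model that defaults to caching. A hedged usage sketch (the checkpoint name is only an example; any causal LM behaves the same way):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder")
    model.gradient_checkpointing_enable()
    model.train()

    input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
    # With the guard above, use_cache=True during training just logs the warning and runs with use_cache=False.
    outputs = model(input_ids=input_ids, labels=input_ids, use_cache=True)
    outputs.loss.backward()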


@@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     BaseModelOutputWithPastAndCrossAttentions,

@@ -431,7 +432,7 @@ class GPTNeoMLP(nn.Module):
         return hidden_states


-class GPTNeoBlock(nn.Module):
+class GPTNeoBlock(GradientCheckpointingLayer):
     def __init__(self, config, layer_id=None):
         super().__init__()
         hidden_size = config.hidden_size

@@ -635,27 +636,15 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    causal_mask,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=past_key_values,
-                    attention_mask=causal_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
+            outputs = block(
+                hidden_states,
+                layer_past=past_key_values,
+                attention_mask=causal_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )

             hidden_states = outputs[0]

             if use_cache:


@@ -14,6 +14,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,

@@ -190,7 +191,7 @@ class GPTNeoXAttention(nn.Module):
         return attn_output, attn_weights


-class GPTNeoXLayer(nn.Module):
+class GPTNeoXLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual

@@ -415,32 +416,18 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    layer.__call__,
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    head_mask[i],
-                    use_cache,
-                    past_key_values,
-                    output_attentions,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                outputs = layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    layer_past=past_key_values,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **flash_attn_kwargs,
-                )
+            outputs = layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                layer_past=past_key_values,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
+            )

             hidden_states = outputs[0]

             if output_attentions:


@@ -9,6 +9,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,

@@ -177,7 +178,7 @@ class GPTNeoXAttention(nn.Module):
         return attn_output, attn_weights


-class GPTNeoXLayer(nn.Module):
+class GPTNeoXLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual

@@ -362,32 +363,18 @@ class GPTNeoXModel(LlamaModel, nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    layer.__call__,
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    head_mask[i],
-                    use_cache,
-                    past_key_values,
-                    output_attentions,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                outputs = layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    layer_past=past_key_values,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **flash_attn_kwargs,
-                )
+            outputs = layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                layer_past=past_key_values,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
+            )

             hidden_states = outputs[0]

             if output_attentions:


@@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,

@@ -434,7 +435,7 @@ class GPTJMLP(nn.Module):
         return hidden_states


-class GPTJBlock(nn.Module):
+class GPTJBlock(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd

@@ -738,29 +739,16 @@ class GPTJModel(GPTJPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    causal_mask,
-                    position_ids,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                )
-            else:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=past_key_values,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
+            outputs = block(
+                hidden_states,
+                layer_past=past_key_values,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )

             hidden_states = outputs[0]

             if use_cache is True:


@@ -25,6 +25,7 @@ from torch import nn

 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging, torch_int

@@ -692,7 +693,7 @@ class GroupViTAttention(nn.Module):


 # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
-class GroupViTEncoderLayer(nn.Module):
+class GroupViTEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: GroupViTConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -906,21 +907,12 @@ class GroupViTTextEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]


@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BackboneOutput,
     BaseModelOutput,

@@ -540,7 +541,7 @@ class HieraLayer(nn.Module):
         return (hidden_states, attn_weights)


-class HieraStage(nn.Module):
+class HieraStage(GradientCheckpointingLayer):
     def __init__(
         self,
         config,

@@ -734,12 +735,7 @@ class HieraEncoder(nn.Module):
         for i, stage_module in enumerate(self.stages):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    stage_module.__call__, hidden_states, layer_head_mask, output_attentions
-                )
-            else:
-                layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]


@@ -32,6 +32,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...integrations.fsdp import is_fsdp_managed_module
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack

@@ -107,7 +108,7 @@ class HubertSamePadLayer(nn.Module):
         return hidden_states


-class HubertNoLayerNormConvLayer(nn.Module):
+class HubertNoLayerNormConvLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1

@@ -128,7 +129,7 @@ class HubertNoLayerNormConvLayer(nn.Module):
         return hidden_states


-class HubertLayerNormConvLayer(nn.Module):
+class HubertLayerNormConvLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1

@@ -155,7 +156,7 @@ class HubertLayerNormConvLayer(nn.Module):
         return hidden_states


-class HubertGroupNormConvLayer(nn.Module):
+class HubertGroupNormConvLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1

@@ -212,13 +213,7 @@ class HubertFeatureEncoder(nn.Module):
             hidden_states.requires_grad = True

         for conv_layer in self.conv_layers:
-            if self._requires_grad and self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    conv_layer.__call__,
-                    hidden_states,
-                )
-            else:
-                hidden_states = conv_layer(hidden_states)
+            hidden_states = conv_layer(hidden_states)

         return hidden_states

@@ -417,7 +412,7 @@ class HubertFeedForward(nn.Module):
         return hidden_states


-class HubertEncoderLayer(nn.Module):
+class HubertEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.attention = HubertAttention(

@@ -501,17 +496,9 @@ class HubertEncoder(nn.Module):
             skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
             if not skip_the_layer or synced_gpus:
                 # under fsdp or deepspeed zero3 all gpus must run in sync
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = layer(
-                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
-                    )
+                layer_outputs = layer(
+                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+                )
                 hidden_states = layer_outputs[0]

             if skip_the_layer:

@@ -579,7 +566,7 @@ class HubertAttnAdapterLayer(nn.Module):
         return hidden_states


-class HubertEncoderLayerStableLayerNorm(nn.Module):
+class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.attention = HubertAttention(

@@ -675,17 +662,9 @@ class HubertEncoderStableLayerNorm(nn.Module):
             if not skip_the_layer or synced_gpus:
                 # under fsdp or deepspeed zero3 all gpus must run in sync
                 # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = layer(
-                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
-                    )
+                layer_outputs = layer(
+                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+                )
                 hidden_states = layer_outputs[0]

             if skip_the_layer:


@@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import ModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel
 from ...processing_utils import Unpack

@@ -668,7 +669,7 @@ class IdeficsAttention(nn.Module):


 # this was adapted from LlamaDecoderLayer
-class IdeficsDecoderLayer(nn.Module):
+class IdeficsDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.hidden_size = config.hidden_size

@@ -749,7 +750,7 @@ class IdeficsDecoderLayer(nn.Module):
         return outputs


-class IdeficsGatedCrossAttentionLayer(nn.Module):
+class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer):
     def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.hidden_size = config.hidden_size

@@ -1185,95 +1186,32 @@ class IdeficsModel(IdeficsPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            def vblock(
-                main_block,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                image_hidden_states,
-                image_attention_mask,
-                cross_attention_gate,
-                output_attentions,
-                use_cache,
-                layer_idx,
-                cross_layer_interval,
-                gated_cross_attn_layers,
-                cache_position,
-            ):
-                # TODO(ls): Add cross attention values to respective lists
-                if layer_idx % cross_layer_interval == 0:
-                    xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
-                    outputs = xblock(
-                        hidden_states,
-                        attention_mask=attention_mask,
-                        image_hidden_states=image_hidden_states,
-                        image_attention_mask=image_attention_mask,
-                        cross_attention_gate=cross_attention_gate,
-                        output_attentions=output_attentions,
-                        use_cache=use_cache,
-                        past_key_value=None,  # not implemented
-                        **kwargs,
-                    )
-                    hidden_states = outputs[0]
-
-                layer_outputs = main_block(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **kwargs,
-                )
-
-                return layer_outputs
-
-            if self.gradient_checkpointing and self.training:
-                past_key_values = None
-                if use_cache:
-                    logger.warning_once(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-
-                layer_outputs = self._gradient_checkpointing_func(
-                    vblock,
-                    decoder_layer,
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                    past_key_values,
-                    image_hidden_states,
-                    image_attention_mask,
-                    cross_attention_gate,
-                    output_attentions,
-                    use_cache,
-                    idx,
-                    self.cross_layer_interval,
-                    self.gated_cross_attn_layers,
-                    cache_position,
-                )
-            else:
-                layer_outputs = vblock(
-                    decoder_layer,
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    image_hidden_states=image_hidden_states,
-                    image_attention_mask=image_attention_mask,
-                    cross_attention_gate=cross_attention_gate,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    layer_idx=idx,
-                    cross_layer_interval=self.cross_layer_interval,
-                    gated_cross_attn_layers=self.gated_cross_attn_layers,
-                    cache_position=cache_position,
-                    **kwargs,
-                )
+            # TODO(ls): Add cross attention values to respective lists
+            if idx % self.cross_layer_interval == 0:
+                cross_attn_block = self.gated_cross_attn_layers[idx // self.cross_layer_interval]
+                outputs = cross_attn_block(
+                    hidden_states,
+                    attention_mask,
+                    image_hidden_states,
+                    image_attention_mask=image_attention_mask,
+                    cross_attention_gate=cross_attention_gate,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    past_key_value=None,  # not implemented
+                    **kwargs,
+                )
+                hidden_states = outputs[0]
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )

             hidden_states = layer_outputs[0]

             if use_cache:


@@ -23,6 +23,7 @@ import torch.utils.checkpoint
 from torch import nn

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging, torch_int

@@ -283,7 +284,7 @@ class IdeficsVisionMLP(nn.Module):


 # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision
-class IdeficsVisionEncoderLayer(nn.Module):
+class IdeficsVisionEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: IdeficsVisionConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -400,21 +401,12 @@ class IdeficsVisionEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack

@@ -339,7 +340,7 @@ class Idefics2MultiheadAttentionPoolingHead(nn.Module):
         return hidden_state[:, 0]


-class Idefics2EncoderLayer(nn.Module):
+class Idefics2EncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Idefics2VisionConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -448,19 +449,11 @@ class Idefics2Encoder(nn.Module):
         for encoder_layer in self.layers:
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from ...cache_utils import DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack

@@ -300,7 +301,7 @@ class Idefics3SimpleMLP(nn.Module):


 # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
-class Idefics3EncoderLayer(nn.Module):
+class Idefics3EncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Idefics3VisionConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -409,19 +410,11 @@ class Idefics3Encoder(nn.Module):
         for encoder_layer in self.layers:
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]


@@ -12,6 +12,7 @@ import torch.nn as nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

@@ -357,7 +358,7 @@ class IJepaOutput(nn.Module):
         return hidden_states


-class IJepaLayer(nn.Module):
+class IJepaLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""

     def __init__(self, config: IJepaConfig) -> None:

@@ -423,15 +424,7 @@ class IJepaEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    layer_head_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]


@@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,

@@ -401,7 +402,7 @@ class ImageGPTMLP(nn.Module):
         return hidden_states


-class ImageGPTBlock(nn.Module):
+class ImageGPTBlock(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size

@@ -719,29 +720,16 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    use_cache,
-                    output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            outputs = block(
+                hidden_states,
+                layer_past,
+                attention_mask,
+                head_mask[i],
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )

             hidden_states = outputs[0]

             if use_cache is True:


@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,

@@ -744,7 +745,7 @@ class InformerProbSparseAttention(nn.Module):


 # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py
-class InformerConvLayer(nn.Module):
+class InformerConvLayer(GradientCheckpointingLayer):
     def __init__(self, c_in):
         super().__init__()
         self.downConv = nn.Conv1d(

@@ -767,7 +768,7 @@ class InformerConvLayer(nn.Module):
         return x


-class InformerEncoderLayer(nn.Module):
+class InformerEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InformerConfig):
         super().__init__()
         self.embed_dim = config.d_model

@@ -845,7 +846,7 @@ class InformerEncoderLayer(nn.Module):
         return outputs


-class InformerDecoderLayer(nn.Module):
+class InformerDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model

@@ -1086,27 +1087,15 @@ class InformerEncoder(InformerPreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                    if conv_layer is not None:
-                        output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
-                        layer_outputs = (output,) + layer_outputs[1:]
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
-                    if conv_layer is not None:
-                        output = conv_layer(layer_outputs[0])
-                        layer_outputs = (output,) + layer_outputs[1:]
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
+                if conv_layer is not None:
+                    output = conv_layer(layer_outputs[0])
+                    layer_outputs = (output,) + layer_outputs[1:]

                 hidden_states = layer_outputs[0]

@@ -1299,35 +1288,18 @@ class InformerDecoder(InformerPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )

             hidden_states = layer_outputs[0]

             if use_cache:


@@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
 )

@@ -433,7 +434,7 @@ class InformerProbSparseAttention(nn.Module):


 # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py
-class InformerConvLayer(nn.Module):
+class InformerConvLayer(GradientCheckpointingLayer):
     def __init__(self, c_in):
         super().__init__()
         self.downConv = nn.Conv1d(

@@ -610,27 +611,15 @@ class InformerEncoder(TimeSeriesTransformerEncoder):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                    if conv_layer is not None:
-                        output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
-                        layer_outputs = (output,) + layer_outputs[1:]
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
-                    if conv_layer is not None:
-                        output = conv_layer(layer_outputs[0])
-                        layer_outputs = (output,) + layer_outputs[1:]
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
+                if conv_layer is not None:
+                    output = conv_layer(layer_outputs[0])
+                    layer_outputs = (output,) + layer_outputs[1:]

                 hidden_states = layer_outputs[0]


@@ -427,7 +427,7 @@ class InstructBlipEncoder(nn.Module):

             layer_outputs = encoder_layer(
                 hidden_states,
-                attention_mask,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
             )

@@ -889,11 +889,11 @@ class InstructBlipQFormerEncoder(nn.Module):
                 hidden_states,
                 attention_mask,
                 layer_head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                past_key_value,
-                output_attentions,
-                query_length,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                query_length=query_length,
             )

             hidden_states = layer_outputs[0]


@@ -356,7 +356,7 @@ class InstructBlipVideoEncoder(nn.Module):

             layer_outputs = encoder_layer(
                 hidden_states,
-                attention_mask,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
             )

@@ -750,11 +750,11 @@ class InstructBlipVideoQFormerEncoder(nn.Module):
                 hidden_states,
                 attention_mask,
                 layer_head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                past_key_value,
-                output_attentions,
-                query_length,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                query_length=query_length,
             )

             hidden_states = layer_outputs[0]


@@ -31,6 +31,7 @@ from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack

@@ -383,7 +384,7 @@ class InternVLVisionMLP(nn.Module):
 NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}


-class InternVLVisionLayer(nn.Module):
+class InternVLVisionLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""

     def __init__(self, config: InternVLVisionConfig) -> None:

@@ -452,12 +453,7 @@ class InternVLVisionEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, output_attentions
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, output_attentions)
+            layer_outputs = layer_module(hidden_states, output_attentions)

             hidden_states = layer_outputs[0]


@ -24,6 +24,7 @@ import torch.utils.checkpoint
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
@@ -334,7 +335,7 @@ class InternVLVisionMLP(CLIPMLP):
 NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
-class InternVLVisionLayer(nn.Module):
+class InternVLVisionLayer(GradientCheckpointingLayer):
     """This corresponds to the Block class in the timm implementation."""
     def __init__(self, config: InternVLVisionConfig) -> None:
@@ -403,12 +404,7 @@ class InternVLVisionEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, output_attentions
-                )
-            else:
-                layer_outputs = layer_module(hidden_states, output_attentions)
+            layer_outputs = layer_module(hidden_states, output_attentions)
             hidden_states = layer_outputs[0]


@@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
@@ -763,7 +764,7 @@ JETMOE_ATTENTION_CLASSES = {
 }
-class JetMoeBlock(nn.Module):
+class JetMoeBlock(GradientCheckpointingLayer):
     def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
         """
         Initialize the JetMoeBlock module.
@@ -967,28 +968,15 @@ class JetMoeModel(JetMoePreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    position_ids,
-                    past_key_values,
-                    causal_mask,
-                    output_attentions,
-                    output_router_logits,
-                    use_cache,
-                    use_reentrant=False,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    output_router_logits=output_router_logits,
-                    use_cache=use_cache,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                output_router_logits=output_router_logits,
+                use_cache=use_cache,
+            )
             hidden_states = layer_outputs[0]
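For users nothing changes: checkpointing is still switched on at the model level and the converted layers pick it up. An illustrative usage sketch; the checkpoint id is only a placeholder, and any model touched by this PR works the same way:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b")  # placeholder id
model.gradient_checkpointing_enable()  # layers recompute activations during backward
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()  # memory stays low at the cost of an extra forward per layer
```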


@@ -25,6 +25,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -404,7 +405,7 @@ class Kosmos2VisionMLP(nn.Module):
 # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision
-class Kosmos2VisionEncoderLayer(nn.Module):
+class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Kosmos2VisionConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -521,21 +522,12 @@ class Kosmos2VisionEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    causal_attention_mask,
-                    output_attentions=output_attentions,
-                )
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
@@ -840,7 +832,7 @@ class Kosmos2TextFFN(nn.Module):
         return hidden_states
-class Kosmos2TextBlock(nn.Module):
+class Kosmos2TextBlock(GradientCheckpointingLayer):
     def __init__(self, config: Kosmos2TextConfig):
         super().__init__()
         self.embed_dim = config.embed_dim
@@ -1138,34 +1130,18 @@ class Kosmos2TextTransformer(nn.Module):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                **kwargs,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:


@@ -23,6 +23,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -358,7 +359,7 @@ class LayoutLMOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
-class LayoutLMLayer(nn.Module):
+class LayoutLMLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -484,27 +485,15 @@ class LayoutLMEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]
             if use_cache:


@@ -23,6 +23,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -261,7 +262,7 @@ class LayoutLMv2Output(nn.Module):
         return hidden_states
-class LayoutLMv2Layer(nn.Module):
+class LayoutLMv2Layer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -436,25 +437,14 @@ class LayoutLMv2Encoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                    rel_pos=rel_pos,
-                    rel_2d_pos=rel_2d_pos,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                    rel_pos=rel_pos,
-                    rel_2d_pos=rel_2d_pos,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                output_attentions,
+                rel_pos=rel_pos,
+                rel_2d_pos=rel_2d_pos,
+            )
             hidden_states = layer_outputs[0]
             if output_attentions:


@@ -25,6 +25,7 @@ import torch.utils.checkpoint
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     QuestionAnsweringModelOutput,
@@ -358,7 +359,7 @@ class LayoutLMv3Attention(nn.Module):
 # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
-class LayoutLMv3Layer(nn.Module):
+class LayoutLMv3Layer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -514,25 +515,14 @@ class LayoutLMv3Encoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                    rel_pos,
-                    rel_2d_pos,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                    rel_pos=rel_pos,
-                    rel_2d_pos=rel_2d_pos,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                output_attentions,
+                rel_pos=rel_pos,
+                rel_2d_pos=rel_2d_pos,
+            )
             hidden_states = layer_outputs[0]
             if output_attentions:


@@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
@@ -900,7 +901,7 @@ class LEDDecoderAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
-class LEDEncoderLayer(nn.Module):
+class LEDEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: LEDConfig, layer_id: int):
         super().__init__()
         self.embed_dim = config.d_model
@@ -962,7 +963,7 @@ class LEDEncoderLayer(nn.Module):
         return (hidden_states,) + attn_outputs[1:]
-class LEDDecoderLayer(nn.Module):
+class LEDDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: LEDConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -1680,27 +1681,15 @@ class LEDEncoder(LEDPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        head_mask[idx] if head_mask is not None else None,
-                        is_index_masked,
-                        is_index_global_attn,
-                        is_global_attn,
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask=attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        is_index_masked=is_index_masked,
-                        is_index_global_attn=is_index_global_attn,
-                        is_global_attn=is_global_attn,
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    is_index_masked=is_index_masked,
+                    is_index_global_attn=is_index_global_attn,
+                    is_global_attn=is_global_attn,
+                    output_attentions=output_attentions,
+                )
             hidden_states = layer_outputs[0]
             if output_attentions:
@@ -1943,33 +1932,17 @@ class LEDDecoder(LEDPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    combined_attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=combined_attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                combined_attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
             hidden_states = layer_outputs[0]

Some files were not shown because too many files have changed in this diff.