Apply GradientCheckpointingLayer to the whole repo (#38913)

* first batch (4)

* align

* altclip

* beit

* bert

* yolos

* dino, pvt_v2

* bark, bart, bert_generation

* big_bird, biogpt

* blenderbot, bloom

* bridgetower

* camembert, canine, chameleon

* chinese clip, clap, clip

* codegen, conditional detr, convbert

* dab_detr, data2vec

* dbrx, deberta

* deberta, decision_transformer, deformable_detr

* deit, deta, mctct

* detr, dinov2, distilbert

* donut, dpt, electra

* ernie, esm, falcon

* flava, fnet, falcon_mamba

* focalnet, git, gpt2

* gpt_bigcode, gpt_neo, gpt_neox

* gptj, groupvit

* idefics2, idefics3

* ijepa, imagegpt, internvl

* jetmoe, kosmos2, layoutlm

* layoutlm2-3, led

* lilt, longformer, longt5, luke

* m2m, mamba1-2

* marian, markuplm, mask2former

* maskformer

* mbart, megatron_bert, mimi

* mixtral, mlcd

* mobilevit1-2, modernbert

* moshi, mpt, mra

* mt5, musicgen

* mvp, nemotron

* nllb_moe

* nystromformer, omdet_turbo

* opt, owlvit, owlv2

* pegasus, pegasus_x, persimmon

* phimoe, pix2struct, pixtral

* plbart, pop2piano, prophetnet

* qwen2*

* qwen2, qwen3 moe, recurrent gemma

* rembert

* roberta

* roberta prelayernorm

* roc_bert, roformer, rwkv

* sam, sam_hq

* seggpt, smolvlm, speech_to_text

* splinter, stablelm, swin

* swin2sr, switch_transformers, t5, table_transformer

* tapas, time_series_transformer, timesformer

* trocr, tvp, umt5

* videomae, vilt, visual_bert

* vit, vit_mae, vit_msn

* vitpose_backbone, vits, vivit

* whisper, x_clip, xglm

* xlm_roberta, xmod

* yoso

* zamba

* vitdet, wav2vec2, wav2vec2_bert

* unispeech, wav2vec_conformer

* wavlm

* speecht5

* swinv2

* sew / _d

* seamless_m4t / _v2

* deprecated models update

* bros

* gemma2, gemma3

* got, hiera, hubert, llama4, mllama, oneformer, phi, olmoe, informer

* fixup

* Add use_cache=False and past_key_value=None to GradientCheckpointingLayer

* fixup

* fix prophetnet

* fix bigbird_pegasus

* fix blenderbot

* fix mbart

* fix mvp

* fix zamba2

* fix bart

* fix blenderbot_small

* fix codegen

* Update gradient checkpointing layer to support more past_key_values arg names

* fix data2vec vision

* fix deformable_detr

* fix gptj

* fix led

* fix m2m_100

* add comment

* fix nllb_moe

* Fix pegasus_x

* fix plbart

* udop

* fix-copies: beit, wav2vec2

* fix gpt_bigcode

* fixup

* fix t5

* fix switch_transformers

* fix longt5

* fix mt5

* update tapas

* fix blip2

* update blip

* fix musicgen

* fix gpt2, trocr

* fix copies

* !!! Revert zamba, mllama

* update autoformer

* update bros

* update args / kwargs for BERT and copies

* 2nd round of updates

* update conditional detr

* Pass encoder_hidden_states as positional arg

* Update to pass encoder_decoder_position_bias as positional arg

* fixup

* biogpt modular

* modular gemma2

* modular gemma3

* modular gpt_neox

* modular informer

* modular internvl

* modular mixtral

* modular mlcd

* modular modernbert

* modular phi

* modular qwen2_5_omni

* modular qwen2_5_vl

* modular sam_hq

* modular sew

* wav2vec2_bert

* modular wav2vec2_conformer

* modular wavlm

* fixup

* Update by modular instructblipvideo

* modular data2vec_audio

* nit modular mistral

* apply modular minimax

* fix modular moonshine

* revert zamba2

* fix mask2former

* refactor idefics
Pavel Iakubovskii 2025-06-23 13:24:48 +01:00 committed by GitHub
parent 07aab1af1e
commit 84d19be41e
224 changed files with 2513 additions and 5280 deletions

View File

@ -16,6 +16,11 @@ from functools import partial
import torch.nn as nn
from transformers.utils import logging
logger = logging.get_logger(__name__)
class GradientCheckpointingLayer(nn.Module):
"""Base class for layers with gradient checkpointing.
@ -44,5 +49,35 @@ class GradientCheckpointingLayer(nn.Module):
    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            do_warn = False
            layer_name = self.__class__.__name__
            message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"

            if "use_cache" in kwargs and kwargs["use_cache"]:
                kwargs["use_cache"] = False
                message += " `use_cache=False`,"
                do_warn = True

            # different names for the same thing in different layers
            if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
                kwargs["past_key_value"] = None
                message += " `past_key_value=None`,"
                do_warn = True

            if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
                kwargs["past_key_values"] = None
                message += " `past_key_values=None`,"
                do_warn = True

            if "layer_past" in kwargs and kwargs["layer_past"] is not None:
                kwargs["layer_past"] = None
                message += " `layer_past=None`,"
                do_warn = True

            # warn if anything was changed
            if do_warn:
                message = message.rstrip(",") + "."
                logger.warning(message)

            return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
        return super().__call__(*args, **kwargs)
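
For context, a minimal usage sketch (not part of this diff): a layer that subclasses GradientCheckpointingLayer is called directly, and the base class decides whether to route the call through the checkpointing function. ToyLayer and the manual attribute assignments below are illustrative; in real models, gradient_checkpointing and _gradient_checkpointing_func are set by PreTrainedModel.gradient_checkpointing_enable(), and the use_reentrant=False wrapper is just one possible configuration.

import torch
from torch.utils.checkpoint import checkpoint

from transformers.modeling_layers import GradientCheckpointingLayer  # module modified above


class ToyLayer(GradientCheckpointingLayer):  # illustrative layer, not a real transformers class
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, use_cache=False, past_key_value=None):
        return (self.proj(hidden_states),)


layer = ToyLayer(8).train()
# Normally set by PreTrainedModel.gradient_checkpointing_enable(); set by hand for this sketch.
layer.gradient_checkpointing = True
layer._gradient_checkpointing_func = lambda fn, *args: checkpoint(fn, *args, use_reentrant=False)

hidden_states = torch.randn(2, 4, 8, requires_grad=True)
# __call__ flips use_cache to False (with a warning) and forwards positional args as *args of the
# checkpointing call; that is why the per-model diffs below pass tensors such as
# encoder_hidden_states positionally rather than as keyword arguments.
outputs = layer(hidden_states, attention_mask=None, use_cache=True)
outputs[0].sum().backward()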

View File

@ -23,6 +23,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
BaseModelOutputWithPastAndCrossAttentions,
@ -827,7 +828,7 @@ class AlignTextOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
class AlignTextLayer(nn.Module):
class AlignTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -953,27 +954,15 @@ class AlignTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -23,6 +23,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -418,7 +419,7 @@ class AltRobertaOutput(nn.Module):
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
class AltRobertaLayer(nn.Module):
class AltRobertaLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -544,27 +545,15 @@ class AltRobertaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
@ -732,7 +721,7 @@ class AltCLIPMLP(nn.Module):
return hidden_states
class AltCLIPEncoderLayer(nn.Module):
class AltCLIPEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: AltCLIPConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -848,21 +837,12 @@ class AltCLIPEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -22,6 +22,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -282,7 +283,7 @@ class ASTOutput(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
class ASTLayer(nn.Module):
class ASTLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ASTConfig) -> None:
@ -349,16 +350,7 @@ class ASTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:

View File

@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
@ -670,7 +671,7 @@ class AutoformerAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
class AutoformerEncoderLayer(nn.Module):
class AutoformerEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: AutoformerConfig):
super().__init__()
self.embed_dim = config.d_model
@ -744,7 +745,7 @@ class AutoformerEncoderLayer(nn.Module):
return outputs
class AutoformerDecoderLayer(nn.Module):
class AutoformerDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: AutoformerConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1042,21 +1043,12 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1186,6 +1178,12 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.gradient_checkpointing and self.training and use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
input_shape = inputs_embeds.size()[:-1]
# expand encoder attention mask
@ -1228,38 +1226,17 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
(hidden_states, residual_trend) = layer_outputs[0]
trend = trend + residual_trend

View File

@ -31,6 +31,7 @@ from ...generation.logits_process import (
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, get_parameter_device
from ...utils import (
@ -309,7 +310,7 @@ class BarkMLP(nn.Module):
return hidden_states
class BarkBlock(nn.Module):
class BarkBlock(GradientCheckpointingLayer):
def __init__(self, config, is_causal=False):
super().__init__()
@ -606,25 +607,14 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
past_key_values=past_layer_key_values,
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
outputs = block(
hidden_states,
past_key_values=past_layer_key_values,
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]

View File

@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -270,7 +271,7 @@ class BartAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class BartEncoderLayer(nn.Module):
class BartEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -341,7 +342,7 @@ class BartEncoderLayer(nn.Module):
return outputs
class BartDecoderLayer(nn.Module):
class BartDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -875,21 +876,12 @@ class BartEncoder(BartPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1137,35 +1129,18 @@ class BartDecoder(BartPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -26,6 +26,7 @@ from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BackboneOutput,
BaseModelOutput,
@ -497,7 +498,7 @@ class BeitOutput(nn.Module):
return hidden_states
class BeitLayer(nn.Module):
class BeitLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
@ -525,7 +526,7 @@ class BeitLayer(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[tuple[int]] = None,
resolution: Optional[tuple[int, int]] = None,
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
@ -695,25 +696,14 @@ class BeitEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
else:
layer_outputs = layer_module(
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
layer_outputs = layer_module(
hidden_states,
head_mask=layer_head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)
hidden_states = layer_outputs[0]

View File

@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -522,7 +523,7 @@ class BertOutput(nn.Module):
return hidden_states
class BertLayer(nn.Module):
class BertLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -647,27 +648,15 @@ class BertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -23,6 +23,7 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
@ -275,7 +276,7 @@ class BertGenerationOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration
class BertGenerationLayer(nn.Module):
class BertGenerationLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -401,27 +402,15 @@ class BertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -1419,7 +1420,7 @@ class BigBirdOutput(nn.Module):
return hidden_states
class BigBirdLayer(nn.Module):
class BigBirdLayer(GradientCheckpointingLayer):
def __init__(self, config, seed=None):
super().__init__()
self.config = config
@ -1593,35 +1594,19 @@ class BigBirdEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
band_mask,
from_mask,
to_mask,
blocked_encoder_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
band_mask,
from_mask,
to_mask,
blocked_encoder_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
band_mask,
from_mask,
to_mask,
blocked_encoder_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -1333,7 +1334,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class BigBirdPegasusEncoderLayer(nn.Module):
class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BigBirdPegasusConfig, seed=None):
super().__init__()
self.attention_type = config.attention_type
@ -1420,7 +1421,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
self.self_attn.set_attention_type(value)
class BigBirdPegasusDecoderLayer(nn.Module):
class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -1947,31 +1948,17 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
band_mask,
from_mask,
to_mask,
blocked_encoder_mask,
blocked_encoder_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
band_mask=band_mask,
from_mask=from_mask,
to_mask=to_mask,
from_blocked_mask=blocked_encoder_mask,
to_blocked_mask=blocked_encoder_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
band_mask=band_mask,
from_mask=from_mask,
to_mask=to_mask,
from_blocked_mask=blocked_encoder_mask,
to_blocked_mask=blocked_encoder_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -2297,35 +2284,18 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -20,7 +20,6 @@
# limitations under the License.
import math
from functools import partial
from typing import Callable, Optional, Union
import torch
@ -32,6 +31,7 @@ from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -248,7 +248,7 @@ class BioGptAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class BioGptDecoderLayer(nn.Module):
class BioGptDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.hidden_size
@ -646,30 +646,17 @@ class BioGptModel(BioGptPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
position_ids,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_ids=position_ids,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_ids=position_ids,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -15,7 +15,6 @@
"""PyTorch BioGPT model."""
import math
from functools import partial
from typing import Optional, Union
import torch
@ -473,30 +472,17 @@ class BioGptModel(BioGptPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
position_ids,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_ids=position_ids,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_ids=position_ids,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -270,7 +271,7 @@ class BlenderbotAttention(nn.Module):
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotEncoderLayer(nn.Module):
class BlenderbotEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
@ -339,7 +340,7 @@ class BlenderbotEncoderLayer(nn.Module):
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotDecoderLayer(nn.Module):
class BlenderbotDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -825,21 +826,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1090,35 +1082,18 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
causal_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -254,7 +255,7 @@ class BlenderbotSmallAttention(nn.Module):
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallEncoderLayer(nn.Module):
class BlenderbotSmallEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -326,7 +327,7 @@ class BlenderbotSmallEncoderLayer(nn.Module):
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallDecoderLayer(nn.Module):
class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -812,21 +813,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1073,35 +1065,18 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
causal_mask,
encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -552,7 +552,7 @@ class BlipEncoder(nn.Module):
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
attention_mask=attention_mask,
output_attentions=output_attentions,
)

View File

@ -531,7 +531,7 @@ class Blip2Encoder(nn.Module):
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
@ -992,11 +992,11 @@ class Blip2QFormerEncoder(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
query_length=query_length,
)
hidden_states = layer_outputs[0]

View File

@ -27,6 +27,7 @@ from torch.nn import functional as F
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -366,7 +367,7 @@ class BloomMLP(nn.Module):
return output
class BloomBlock(nn.Module):
class BloomBlock(GradientCheckpointingLayer):
def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
super().__init__()
hidden_size = config.hidden_size
@ -605,29 +606,16 @@ class BloomModel(BloomPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
past_key_values,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache:

View File

@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN, QuickGELUActivation
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -662,7 +663,7 @@ class BridgeTowerBertCrossLayer(nn.Module):
return layer_output
class BridgeTowerTextLayer(nn.Module):
class BridgeTowerTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -788,27 +789,15 @@ class BridgeTowerTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -24,6 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -428,7 +429,7 @@ class BrosOutput(nn.Module):
return hidden_states
class BrosLayer(nn.Module):
class BrosLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -550,34 +551,16 @@ class BrosEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
bbox_pos_emb,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states=hidden_states,
bbox_pos_emb=bbox_pos_emb,
attention_mask=attention_mask,
head_mask=layer_head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
layer_outputs = layer_module(
hidden_states,
bbox_pos_emb,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -478,7 +479,7 @@ class CamembertOutput(nn.Module):
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert
class CamembertLayer(nn.Module):
class CamembertLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -604,27 +605,15 @@ class CamembertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
ModelOutput,
@ -672,7 +673,7 @@ class CanineOutput(nn.Module):
return hidden_states
class CanineLayer(nn.Module):
class CanineLayer(GradientCheckpointingLayer):
def __init__(
self,
config,
@ -779,16 +780,7 @@ class CanineEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:

View File

@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -383,7 +384,7 @@ class ChameleonAttention(nn.Module):
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
class ChameleonDecoderLayer(nn.Module):
class ChameleonDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -458,7 +459,7 @@ class ChameleonDecoderLayer(nn.Module):
return outputs
class ChameleonSwinDecoderLayer(nn.Module):
class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -1011,28 +1012,16 @@ class ChameleonModel(ChameleonPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -23,6 +23,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -577,7 +578,7 @@ class ChineseCLIPVisionMLP(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
class ChineseCLIPTextLayer(nn.Module):
class ChineseCLIPTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -663,7 +664,7 @@ class ChineseCLIPTextLayer(nn.Module):
return layer_output
class ChineseCLIPVisionLayer(nn.Module):
class ChineseCLIPVisionLayer(GradientCheckpointingLayer):
def __init__(self, config: ChineseCLIPConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -816,27 +817,15 @@ class ChineseCLIPTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
@ -920,17 +909,10 @@ class ChineseCLIPVisionEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -24,6 +24,7 @@ import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
@ -691,7 +692,7 @@ class ClapAudioLayer(nn.Module):
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio
class ClapAudioStage(nn.Module):
class ClapAudioStage(GradientCheckpointingLayer):
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
@ -928,14 +929,9 @@ class ClapAudioEncoder(nn.Module):
input_dimensions = self.input_resolutions[i]
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)
hidden_states = layer_outputs[0]
@ -1355,7 +1351,7 @@ class ClapTextOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText
class ClapTextLayer(nn.Module):
class ClapTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -1481,27 +1477,15 @@ class ClapTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
@ -393,7 +394,7 @@ class CLIPMLP(nn.Module):
return hidden_states
class CLIPEncoderLayer(nn.Module):
class CLIPEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
super().__init__()
self.embed_dim = config.hidden_size
@ -575,21 +576,12 @@ class CLIPEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -25,6 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int
@ -374,7 +375,7 @@ class CLIPSegMLP(nn.Module):
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
class CLIPSegEncoderLayer(nn.Module):
class CLIPSegEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: CLIPSegConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -539,22 +540,12 @@ class CLIPSegEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:


@ -24,6 +24,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
@ -245,7 +246,7 @@ class CodeGenMLP(nn.Module):
# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(nn.Module):
class CodeGenBlock(GradientCheckpointingLayer):
# Ignore copy
def __init__(self, config, layer_idx=None):
super().__init__()
@ -437,29 +438,16 @@ class CodeGenModel(CodeGenPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
causal_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache is True:


@ -23,6 +23,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends
@ -827,7 +828,7 @@ class ConditionalDetrEncoderLayer(nn.Module):
return outputs
class ConditionalDetrDecoderLayer(nn.Module):
class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ConditionalDetrConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1297,31 +1298,18 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
pos_transformation = self.query_scale(hidden_states)
# apply transformation
query_sine_embed = query_sine_embed_before_transformation * pos_transformation
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
None,
object_queries,
query_position_embeddings,
query_sine_embed,
encoder_hidden_states,
encoder_attention_mask,
None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=None,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
query_sine_embed=query_sine_embed,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
is_first=(idx == 0),
)
layer_outputs = decoder_layer(
hidden_states,
None, # attention_mask
object_queries,
query_position_embeddings,
query_sine_embed,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
is_first=(idx == 0),
)
hidden_states = layer_outputs[0]
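
Where the old keyword-style call disappears, explicit `None` placeholders with trailing comments keep the positional arguments aligned with the layer's `forward` signature. A toy sketch with a hypothetical signature:

```python
def decoder_layer(hidden_states, attention_mask=None, object_queries=None,
                  query_position_embeddings=None, query_sine_embed=None,
                  encoder_hidden_states=None, **kwargs):
    # stand-in for a detr-style decoder layer forward
    return (hidden_states,)


out = decoder_layer(
    "h",
    None,              # attention_mask placeholder, mirrors the call above
    "queries",
    "query_pos",
    "sine_embed",
    "encoder_states",  # lands in the encoder_hidden_states slot
)
```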


@ -25,6 +25,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, get_activation
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
MaskedLMOutput,
@ -532,7 +533,7 @@ class ConvBertOutput(nn.Module):
return hidden_states
class ConvBertLayer(nn.Module):
class ConvBertLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -620,25 +621,14 @@ class ConvBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)


@ -23,6 +23,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@ -702,7 +703,7 @@ class DabDetrDecoderLayerFFN(nn.Module):
# Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig
class DabDetrEncoderLayer(nn.Module):
class DabDetrEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DabDetrConfig):
super().__init__()
self.hidden_size = config.hidden_size
@ -764,7 +765,7 @@ class DabDetrEncoderLayer(nn.Module):
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr
class DabDetrDecoderLayer(nn.Module):
class DabDetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DabDetrConfig, is_first: bool = False):
super().__init__()
self.self_attn = DabDetrDecoderLayerSelfAttention(config)
@ -976,21 +977,12 @@ class DabDetrEncoder(DabDetrPreTrainedModel):
# we add object_queries * pos_scaler as extra input to the encoder_layer
scaled_object_queries = object_queries * pos_scales
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
scaled_object_queries,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
object_queries=scaled_object_queries,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
object_queries=scaled_object_queries,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1138,29 +1130,16 @@ class DabDetrDecoder(DabDetrPreTrainedModel):
reference_anchor_size[..., 1] / obj_center[..., 3]
).unsqueeze(-1)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
None,
object_queries,
query_pos,
query_sine_embed,
encoder_hidden_states,
memory_key_padding_mask,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=None,
object_queries=object_queries,
query_position_embeddings=query_pos,
query_sine_embed=query_sine_embed,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=memory_key_padding_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
None, # attention_mask
object_queries,
query_pos,
query_sine_embed,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=memory_key_padding_mask,
output_attentions=output_attentions,
)
# iter update
hidden_states = layer_outputs[0]


@ -33,6 +33,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@ -51,7 +52,7 @@ if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask
class Data2VecAudioConvLayer(nn.Module):
class Data2VecAudioConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -155,13 +156,7 @@ class Data2VecAudioFeatureEncoder(nn.Module):
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
hidden_states = conv_layer(hidden_states)
return hidden_states
@ -357,7 +352,7 @@ class Data2VecAudioFeedForward(nn.Module):
return hidden_states
class Data2VecAudioEncoderLayer(nn.Module):
class Data2VecAudioEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = Data2VecAudioAttention(
@ -441,17 +436,9 @@ class Data2VecAudioEncoder(nn.Module):
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if skip_the_layer:
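
The feature encoder above still sets `hidden_states.requires_grad = True` before the conv stack. With the reentrant checkpointing backend, a segment whose inputs carry no grad requirement produces an output that does not require grad either, so the segment's parameters would receive no gradients. A toy illustration, plain PyTorch only:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)  # requires_grad=False, like raw audio features

y = checkpoint(layer, x, use_reentrant=True)
print(y.requires_grad)  # False: the linear layer would get no gradients

y = checkpoint(layer, x.requires_grad_(True), use_reentrant=True)
print(y.requires_grad)  # True
```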


@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -375,7 +376,7 @@ class Data2VecTextOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
class Data2VecTextLayer(nn.Module):
class Data2VecTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -501,27 +502,15 @@ class Data2VecTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:


@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@ -497,7 +498,7 @@ class Data2VecVisionOutput(nn.Module):
# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
class Data2VecVisionLayer(nn.Module):
class Data2VecVisionLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(
@ -527,7 +528,7 @@ class Data2VecVisionLayer(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[tuple[int]] = None,
resolution: Optional[tuple[int, int]] = None,
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
@ -699,25 +700,14 @@ class Data2VecVisionEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
else:
layer_outputs = layer_module(
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
layer_outputs = layer_module(
hidden_states,
head_mask=layer_head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)
hidden_states = layer_outputs[0]


@ -20,6 +20,7 @@ import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ..wav2vec2.modeling_wav2vec2 import (
@ -38,7 +39,7 @@ from ..wav2vec2.modeling_wav2vec2 import (
from .configuration_data2vec_audio import Data2VecAudioConfig
class Data2VecAudioConvLayer(nn.Module):
class Data2VecAudioConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1


@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
@ -724,7 +725,7 @@ class DbrxFFN(nn.Module):
return out, weights
class DbrxBlock(nn.Module):
class DbrxBlock(GradientCheckpointingLayer):
def __init__(self, config: DbrxConfig, block_idx: int):
super().__init__()
self.hidden_size = config.d_model
@ -947,29 +948,16 @@ class DbrxModel(DbrxPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
block_outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cache_position,
)
else:
block_outputs = block(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
)
block_outputs = block(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = block_outputs[0]


@ -22,6 +22,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@ -492,7 +493,7 @@ class DebertaOutput(nn.Module):
return hidden_states
class DebertaLayer(nn.Module):
class DebertaLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = DebertaAttention(config)
@ -580,25 +581,14 @@ class DebertaEncoder(nn.Module):
rel_embeddings = self.get_rel_embedding()
for i, layer_module in enumerate(self.layer):
if self.gradient_checkpointing and self.training:
hidden_states, att_m = self._gradient_checkpointing_func(
layer_module.__call__,
next_kv,
attention_mask,
query_states,
relative_pos,
rel_embeddings,
output_attentions,
)
else:
hidden_states, att_m = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
hidden_states, att_m = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
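
Layers that return several values, like `DebertaLayer` above (`hidden_states, att_m = layer_module(...)`), need no special handling: `torch.utils.checkpoint` returns whatever structure the wrapped callable returns, so the unpacking is the same with or without checkpointing. A toy check:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

lin = nn.Linear(4, 4)


def layer(x):
    out = lin(x)
    return out, out.softmax(dim=-1)  # e.g. (hidden_states, attention_weights)


x = torch.randn(2, 4, requires_grad=True)
hidden, attn = checkpoint(layer, x, use_reentrant=False)
print(hidden.shape, attn.shape)
```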


@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@ -418,7 +419,7 @@ class DebertaV2Output(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
class DebertaV2Layer(nn.Module):
class DebertaV2Layer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = DebertaV2Attention(config)
@ -655,25 +656,14 @@ class DebertaV2Encoder(nn.Module):
next_kv = hidden_states
rel_embeddings = self.get_rel_embedding()
for i, layer_module in enumerate(self.layer):
if self.gradient_checkpointing and self.training:
output_states, attn_weights = self._gradient_checkpointing_func(
layer_module.__call__,
next_kv,
attention_mask,
query_states,
relative_pos,
rel_embeddings,
output_attentions,
)
else:
output_states, attn_weights = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
output_states, attn_weights = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attn_weights,)


@ -25,6 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
@ -360,7 +361,7 @@ class DecisionTransformerGPT2MLP(nn.Module):
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
class DecisionTransformerGPT2Block(GradientCheckpointingLayer):
# Ignore copy
def __init__(self, config, layer_idx=None):
super().__init__()
@ -654,31 +655,17 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
past_key_value=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
outputs = block(
hidden_states,
past_key_values if not (self.gradient_checkpointing and self.training) else None,
cache_position,
attention_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
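
Note the guard `past_key_values if not (self.gradient_checkpointing and self.training) else None` above: a checkpointed block is re-executed while gradients are computed, so any side effect such as appending to a KV cache would happen twice. A toy illustration of that replay:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

lin = nn.Linear(4, 4)
cache = []


def block(x):
    out = lin(x)
    cache.append(out.detach())  # side effect, like writing a KV cache entry
    return out


x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
print(len(cache))  # 2, not 1: the forward ran again during the backward pass
```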


@ -27,6 +27,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
@ -759,7 +760,7 @@ class DeformableDetrMultiheadAttention(nn.Module):
return attn_output, attn_weights_reshaped
class DeformableDetrEncoderLayer(nn.Module):
class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DeformableDetrConfig):
super().__init__()
self.embed_dim = config.d_model
@ -848,7 +849,7 @@ class DeformableDetrEncoderLayer(nn.Module):
return outputs
class DeformableDetrDecoderLayer(nn.Module):
class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DeformableDetrConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1126,29 +1127,16 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
for i, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
position_embeddings,
reference_points,
spatial_shapes,
spatial_shapes_list,
level_start_index,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1273,31 +1261,17 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
reference_points_input,
spatial_shapes,
spatial_shapes_list,
level_start_index,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings,
reference_points_input,
spatial_shapes,
spatial_shapes_list,
level_start_index,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]


@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@ -347,7 +348,7 @@ class DeiTOutput(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
class DeiTLayer(nn.Module):
class DeiTLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: DeiTConfig) -> None:
@ -414,15 +415,7 @@ class DeiTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]


@ -39,6 +39,7 @@ from ....file_utils import (
replace_return_docstrings,
)
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import meshgrid
@ -909,7 +910,7 @@ class DetaEncoderLayer(nn.Module):
return outputs
class DetaDecoderLayer(nn.Module):
class DetaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DetaConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1341,29 +1342,16 @@ class DetaDecoder(DetaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
reference_points_input,
spatial_shapes,
level_start_index,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]


@ -26,6 +26,7 @@ from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add
from ....integrations.deepspeed import is_deepspeed_zero3_enabled
from ....integrations.fsdp import is_fsdp_managed_module
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput, CausalLMOutput
from ....modeling_utils import (
PreTrainedModel,
@ -377,7 +378,7 @@ class MCTCTOutput(nn.Module):
return hidden_states
class MCTCTLayer(nn.Module):
class MCTCTLayer(GradientCheckpointingLayer):
def __init__(self, config: MCTCTConfig):
super().__init__()
@ -591,20 +592,11 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]


@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -438,7 +439,7 @@ class NezhaOutput(nn.Module):
return hidden_states
class NezhaLayer(nn.Module):
class NezhaLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -563,27 +564,15 @@ class NezhaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:


@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ....activations import ACT2FN
from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ....modeling_utils import PreTrainedModel
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@ -339,7 +340,7 @@ class OpenLlamaAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class OpenLlamaDecoderLayer(nn.Module):
class OpenLlamaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: OpenLlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
@ -631,25 +632,14 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]


@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -452,7 +453,7 @@ class QDQBertOutput(nn.Module):
# Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert
class QDQBertLayer(nn.Module):
class QDQBertLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.seq_len_dim = 1
@ -568,32 +569,15 @@ class QDQBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:


@ -24,6 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -447,7 +448,7 @@ class RealmOutput(nn.Module):
return hidden_states
class RealmLayer(nn.Module):
class RealmLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -572,27 +573,15 @@ class RealmEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:


@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
from ....activations import ACT2FN
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ....modeling_utils import PreTrainedModel
from ....utils import add_start_docstrings, logging, replace_return_docstrings
@ -263,7 +264,7 @@ class Speech2Text2Attention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
class Speech2Text2DecoderLayer(nn.Module):
class Speech2Text2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Speech2Text2Config):
super().__init__()
self.embed_dim = config.d_model
@ -612,31 +613,17 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:


@ -25,6 +25,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_utils import PreTrainedModel
from ....utils import (
ModelOutput,
@ -346,7 +347,7 @@ class CausalSelfAttention(nn.Module):
return outputs
class Block(nn.Module):
class Block(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
@ -540,16 +541,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
layer_past,
use_cache,
output_attentions,
)
else:
outputs = block(hidden_states, layer_past, use_cache, output_attentions)
outputs = block(hidden_states, layer_past, use_cache, output_attentions)
hidden_states = outputs[0]
if use_cache is True:


@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -483,7 +484,7 @@ class TvltOutput(nn.Module):
return hidden_states
class TvltLayer(nn.Module):
class TvltLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config):
@ -546,16 +547,7 @@ class TvltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
@ -853,15 +845,7 @@ class TvltDecoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
None,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
hidden_states = layer_outputs[0]


@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -390,7 +391,7 @@ VIT_HYBRID_ATTENTION_CLASSES = {
}
class ViTHybridLayer(nn.Module):
class ViTHybridLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTHybridConfig) -> None:
@ -457,15 +458,7 @@ class ViTHybridEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]


@ -26,6 +26,7 @@ from torch import Tensor, nn
from torch.nn import LayerNorm
from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput
from ....modeling_utils import PreTrainedModel
from ....utils import (
@ -1090,7 +1091,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
return predict_relative_pos_embeddings
class XLMProphetNetEncoderLayer(nn.Module):
class XLMProphetNetEncoderLayer(GradientCheckpointingLayer):
"""
Encoder block for XLMProphetnet
"""
@ -1133,7 +1134,7 @@ class XLMProphetNetEncoderLayer(nn.Module):
return outputs
class XLMProphetNetDecoderLayer(nn.Module):
class XLMProphetNetDecoderLayer(GradientCheckpointingLayer):
"""
Decoder block for XLMProphetnet
"""
@ -1320,21 +1321,12 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
if output_hidden_states:
encoder_hidden_states = encoder_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
extended_attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1554,41 +1546,21 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
extended_attention_mask,
encoder_hidden_states,
extended_encoder_attention_mask,
(head_mask[idx] if head_mask is not None else None),
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
extended_predict_attention_mask,
main_relative_position_buckets,
predict_relative_position_buckets,
position_ids,
None,
use_cache,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
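
Unlike most call sites in this commit, the deprecated XLMProphetNet decoder above keeps its tensors as keyword arguments. That still composes with PyTorch's non-reentrant checkpointing, which propagates gradients through keyword tensors as well; a toy check, not the code path transformers necessarily takes for this model:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

proj = nn.Linear(8, 8)


def fuse(hidden_states, encoder_hidden_states=None):
    out = proj(hidden_states)
    return out + encoder_hidden_states if encoder_hidden_states is not None else out


h = torch.randn(2, 8, requires_grad=True)
enc = torch.randn(2, 8, requires_grad=True)

out = checkpoint(fuse, h, encoder_hidden_states=enc, use_reentrant=False)
out.sum().backward()
print(enc.grad is not None)  # True: keyword tensors still receive gradients
```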


@ -23,6 +23,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@ -677,7 +678,7 @@ class DetrEncoderLayer(nn.Module):
return outputs
class DetrDecoderLayer(nn.Module):
class DetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DetrConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1045,25 +1046,15 @@ class DetrDecoder(DetrPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
combined_attention_mask,
object_queries,
query_position_embeddings,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]


@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -382,7 +383,7 @@ class Dinov2SwiGLUFFN(nn.Module):
return self.weights_out(hidden)
class Dinov2Layer(nn.Module):
class Dinov2Layer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the original implementation."""
def __init__(self, config: Dinov2Config) -> None:
@ -458,15 +459,7 @@ class Dinov2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]


@ -28,6 +28,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -399,7 +400,7 @@ class Dinov2WithRegistersSwiGLUFFN(nn.Module):
return self.weights_out(hidden)
class Dinov2WithRegistersLayer(nn.Module):
class Dinov2WithRegistersLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the original implementation."""
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
@ -476,15 +477,7 @@ class Dinov2WithRegistersEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]


@ -31,6 +31,7 @@ from ...configuration_utils import PretrainedConfig
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@ -441,7 +442,7 @@ DISTILBERT_ATTENTION_CLASSES = {
}
class TransformerBlock(nn.Module):
class TransformerBlock(GradientCheckpointingLayer):
def __init__(self, config: PretrainedConfig):
super().__init__()
@ -537,21 +538,12 @@ class Transformer(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_state,
attn_mask,
head_mask[i],
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_state,
attn_mask,
head_mask[i],
output_attentions,
)
layer_outputs = layer_module(
hidden_state,
attn_mask,
head_mask[i],
output_attentions,
)
hidden_state = layer_outputs[-1]


@ -27,6 +27,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
@ -706,7 +707,7 @@ class DonutSwinLayer(nn.Module):
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
class DonutSwinStage(GradientCheckpointingLayer):
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
@ -816,19 +817,9 @@ class DonutSwinEncoder(nn.Module):
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
else:
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]


@ -29,6 +29,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -469,7 +470,7 @@ class DPTViTOutput(nn.Module):
# copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput
class DPTViTLayer(nn.Module):
class DPTViTLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: DPTConfig) -> None:
@ -536,15 +537,7 @@ class DPTViTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]


@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, get_activation
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPastAndCrossAttentions,
@ -436,7 +437,7 @@ class ElectraOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
class ElectraLayer(nn.Module):
class ElectraLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -562,27 +563,15 @@ class ElectraEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:


@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -361,7 +362,7 @@ class ErnieOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie
class ErnieLayer(nn.Module):
class ErnieLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -487,27 +488,15 @@ class ErnieEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:


@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -599,7 +600,7 @@ class EsmOutput(nn.Module):
return hidden_states
class EsmLayer(nn.Module):
class EsmLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -725,27 +726,15 @@ class EsmEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File
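In the Electra and Ernie hunks the tensors, `encoder_hidden_states` in particular, are deliberately passed positionally while the flags stay as keywords; the ESM hunk above passes everything positionally. A plausible reason, hedged because the diff only hints at it through the `# as a positional argument for gradient checkpointing` comments: the reentrant flavor of `torch.utils.checkpoint.checkpoint` only threads gradients through positional inputs and rejects extra keyword arguments, so anything that must receive a gradient goes first. A torch-level illustration with made-up names:

import torch
from torch.utils.checkpoint import checkpoint


def tiny_cross_block(hidden_states, encoder_hidden_states):
    # stand-in for a cross-attention block: mixes the two tensor inputs
    return hidden_states + encoder_hidden_states.mean(dim=1, keepdim=True)


hidden = torch.randn(2, 4, 8, requires_grad=True)
encoder_hidden = torch.randn(2, 6, 8, requires_grad=True)

# Both tensors are positional, so the reentrant checkpoint tracks them.
out = checkpoint(tiny_cross_block, hidden, encoder_hidden, use_reentrant=True)
out.sum().backward()
print(encoder_hidden.grad is not None)  # True: gradient reaches the encoder states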

@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
)
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -556,7 +557,7 @@ FALCON_ATTENTION_CLASSES = {
}
class FalconDecoderLayer(nn.Module):
class FalconDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: FalconConfig, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
@ -836,33 +837,18 @@ class FalconModel(FalconPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
position_ids,
head_mask[i],
past_key_values,
use_cache,
output_attentions,
cache_position,
position_embeddings,
)
else:
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = outputs[0]
if use_cache is True:

View File
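From the user's side nothing changes with the Falcon refactor above: checkpointing is still switched on through `gradient_checkpointing_enable()`, which sets the flags the new layer classes check. A short usage sketch; the checkpoint id is a placeholder and the `gradient_checkpointing_kwargs` argument assumes a reasonably recent Transformers version:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b")  # placeholder id
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 32))
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()  # each FalconDecoderLayer recomputes its activations here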

@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import MambaCache
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
@ -405,7 +406,7 @@ class FalconMambaRMSNorm(nn.Module):
# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaBlock(nn.Module):
class FalconMambaBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
@ -620,17 +621,12 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

View File

@ -25,6 +25,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
@ -577,7 +578,7 @@ class FlavaOutput(nn.Module):
return hidden_states
class FlavaLayer(nn.Module):
class FlavaLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: FlavaPossibleConfigs) -> None:
@ -648,16 +649,7 @@ class FlavaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]

View File

@ -31,6 +31,7 @@ if is_scipy_available():
from scipy import linalg
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@ -235,7 +236,7 @@ class FNetOutput(nn.Module):
return hidden_states
class FNetLayer(nn.Module):
class FNetLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -276,10 +277,7 @@ class FNetEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
else:
layer_outputs = layer_module(hidden_states)
layer_outputs = layer_module(hidden_states)
hidden_states = layer_outputs[0]

View File

@ -25,6 +25,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
@ -455,7 +456,7 @@ class FocalNetLayer(nn.Module):
return hidden_state
class FocalNetStage(nn.Module):
class FocalNetStage(GradientCheckpointingLayer):
def __init__(self, config, index, input_resolution):
super().__init__()
@ -560,14 +561,7 @@ class FocalNetEncoder(nn.Module):
all_reshaped_hidden_states += (reshaped_hidden_state,)
for i, stage_module in enumerate(self.stages):
if self.gradient_checkpointing and self.training:
stage_outputs = self._gradient_checkpointing_func(
stage_module.__call__,
hidden_states,
input_dimensions,
)
else:
stage_outputs = stage_module(hidden_states, input_dimensions)
stage_outputs = stage_module(hidden_states, input_dimensions)
hidden_states = stage_outputs[0]
hidden_states_before_downsampling = stage_outputs[1]

View File
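Note that FocalNet, like the other Swin-style backbones, wraps whole stages rather than individual blocks, so a stage is the unit that gets recomputed. As an aside rather than anything this commit does, plain PyTorch exposes the same granularity trade-off through `checkpoint_sequential`:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

stage = nn.Sequential(*[nn.Linear(32, 32) for _ in range(8)])
x = torch.randn(4, 32, requires_grad=True)

# Two segments: only each segment's input is kept; everything inside a
# segment is recomputed during backward. Fewer, larger segments keep fewer
# intermediates alive at the cost of recomputing more at once.
y = checkpoint_sequential(stage, segments=2, input=x, use_reentrant=False)
y.sum().backward()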

@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Union
import torch
@ -30,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -238,7 +238,7 @@ class Gemma2Attention(nn.Module):
return attn_output, attn_weights
class Gemma2DecoderLayer(nn.Module):
class Gemma2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -466,30 +466,17 @@ class Gemma2Model(Gemma2PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File
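The Gemma2 hunk also drops the `partial(decoder_layer.__call__, **flash_attn_kwargs)` construction: the layer is now called with `**flash_attn_kwargs` directly, checkpointing or not. One torch-level mechanism that makes this workable, offered as an assumption about the plumbing rather than a statement of how `GradientCheckpointingLayer` does it: the non-reentrant checkpoint forwards keyword arguments to the wrapped function.

import torch
from torch.utils.checkpoint import checkpoint


def decoder_like(hidden_states, *, scaling=1.0, dropout_p=0.0):
    # stand-in for a decoder layer that takes attention kwargs by keyword
    return torch.nn.functional.dropout(hidden_states * scaling, p=dropout_p)


x = torch.randn(2, 8, requires_grad=True)
# Keyword arguments pass straight through when use_reentrant=False.
y = checkpoint(decoder_like, x, use_reentrant=False, scaling=0.5, dropout_p=0.0)
y.sum().backward()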

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Union
import torch
@ -25,6 +24,7 @@ from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
@ -303,7 +303,7 @@ class Gemma2Attention(GemmaAttention):
return attn_output, attn_weights
class Gemma2DecoderLayer(nn.Module):
class Gemma2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -449,30 +449,17 @@ class Gemma2Model(GemmaModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -22,7 +22,6 @@
import copy
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union
import torch
@ -34,6 +33,7 @@ from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -364,7 +364,7 @@ class Gemma3Attention(nn.Module):
return attn_output, attn_weights
class Gemma3DecoderLayer(nn.Module):
class Gemma3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__()
self.config = config
@ -581,32 +581,18 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings_global,
position_embeddings_local,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -16,7 +16,6 @@
import copy
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, Union
import torch
@ -27,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@ -443,7 +443,7 @@ class Gemma3Attention(Gemma2Attention):
return attn_output, attn_weights
class Gemma3DecoderLayer(nn.Module):
class Gemma3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__()
self.config = config
@ -632,32 +632,18 @@ class Gemma3TextModel(Gemma2Model):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings_global,
position_embeddings_local,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -27,6 +27,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
@ -343,7 +344,7 @@ class GitOutput(nn.Module):
return hidden_states
class GitLayer(nn.Module):
class GitLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -441,24 +442,14 @@ class GitEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
past_key_values,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
past_key_values,
output_attentions,
pixel_values_present,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
past_key_values,
output_attentions,
pixel_values_present,
)
hidden_states = layer_outputs[0]
if use_cache:
@ -723,7 +714,7 @@ class GitVisionAttention(nn.Module):
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
class GitVisionEncoderLayer(nn.Module):
class GitVisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GitVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -840,21 +831,12 @@ class GitVisionEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -31,6 +31,7 @@ import torch.nn.functional as F
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
@ -192,7 +193,7 @@ class GotOcr2VisionAttention(nn.Module):
return outputs
class GotOcr2VisionLayer(nn.Module):
class GotOcr2VisionLayer(GradientCheckpointingLayer):
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -463,13 +464,7 @@ class GotOcr2VisionEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
hidden_states = layer_outputs[0]

View File

@ -30,6 +30,7 @@ from ...activations import ACT2FN, get_activation
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -368,7 +369,7 @@ class GPT2MLP(nn.Module):
return hidden_states
class GPT2Block(nn.Module):
class GPT2Block(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
@ -922,32 +923,18 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
past_key_values,
cache_position,
causal_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
past_key_value=past_key_values,
cache_position=cache_position,
attention_mask=causal_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
outputs = block(
hidden_states,
past_key_values if not (self.gradient_checkpointing and self.training) else None,
cache_position,
causal_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = outputs[0]

View File
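The GPT2 call site is the one place in this stretch where the old branch leaves a visible trace: `past_key_values if not (self.gradient_checkpointing and self.training) else None`. The usual rationale, stated here as an interpretation rather than a quote from the code, is that the backward-time recomputation would otherwise write into the mutable cache a second time, so the cache is dropped (and `use_cache` is forced off elsewhere) while checkpointing. Restated as a standalone helper with a hypothetical name:

def maybe_past_key_values(past_key_values, gradient_checkpointing, training):
    # Drop the cache while checkpointing: the recomputed forward pass would
    # otherwise append to it again during backward.
    if gradient_checkpointing and training:
        return None
    return past_key_values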

@ -25,6 +25,7 @@ from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -558,7 +559,7 @@ GPTBIGCODE_ATTENTION_CLASSES = {
}
class GPTBigCodeBlock(nn.Module):
class GPTBigCodeBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
@ -759,6 +760,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.gradient_checkpointing and self.training and use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -891,29 +898,16 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
outputs = block(
hidden_states,
layer_past,
attention_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache:

View File
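The GPTBigCode hunk adds the standard `use_cache`-versus-checkpointing guard at the top of `forward`. For completeness, the way this combination is usually reached from a training script, hedged in that it assumes the `Trainer` path and a version that accepts `gradient_checkpointing_kwargs`:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # Trainer then calls model.gradient_checkpointing_enable()
    gradient_checkpointing_kwargs={"use_reentrant": False},
)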

@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
@ -431,7 +432,7 @@ class GPTNeoMLP(nn.Module):
return hidden_states
class GPTNeoBlock(nn.Module):
class GPTNeoBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_id=None):
super().__init__()
hidden_size = config.hidden_size
@ -635,27 +636,15 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
causal_mask,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache:

View File

@ -14,6 +14,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -190,7 +191,7 @@ class GPTNeoXAttention(nn.Module):
return attn_output, attn_weights
class GPTNeoXLayer(nn.Module):
class GPTNeoXLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
@ -415,32 +416,18 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
causal_mask,
position_ids,
head_mask[i],
use_cache,
past_key_values,
output_attentions,
cache_position,
position_embeddings,
)
else:
outputs = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
outputs = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = outputs[0]
if output_attentions:

View File

@ -9,6 +9,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -177,7 +178,7 @@ class GPTNeoXAttention(nn.Module):
return attn_output, attn_weights
class GPTNeoXLayer(nn.Module):
class GPTNeoXLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
@ -362,32 +363,18 @@ class GPTNeoXModel(LlamaModel, nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
causal_mask,
position_ids,
head_mask[i],
use_cache,
past_key_values,
output_attentions,
cache_position,
position_embeddings,
)
else:
outputs = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
outputs = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = outputs[0]
if output_attentions:

View File

@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -434,7 +435,7 @@ class GPTJMLP(nn.Module):
return hidden_states
class GPTJBlock(nn.Module):
class GPTJBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
@ -738,29 +739,16 @@ class GPTJModel(GPTJPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
causal_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache is True:

View File

@ -25,6 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int
@ -692,7 +693,7 @@ class GroupViTAttention(nn.Module):
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
class GroupViTEncoderLayer(nn.Module):
class GroupViTEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GroupViTConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -906,21 +907,12 @@ class GroupViTTextEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BackboneOutput,
BaseModelOutput,
@ -540,7 +541,7 @@ class HieraLayer(nn.Module):
return (hidden_states, attn_weights)
class HieraStage(nn.Module):
class HieraStage(GradientCheckpointingLayer):
def __init__(
self,
config,
@ -734,12 +735,7 @@ class HieraEncoder(nn.Module):
for i, stage_module in enumerate(self.stages):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
stage_module.__call__, hidden_states, layer_head_mask, output_attentions
)
else:
layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]

View File

@ -32,6 +32,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -107,7 +108,7 @@ class HubertSamePadLayer(nn.Module):
return hidden_states
class HubertNoLayerNormConvLayer(nn.Module):
class HubertNoLayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -128,7 +129,7 @@ class HubertNoLayerNormConvLayer(nn.Module):
return hidden_states
class HubertLayerNormConvLayer(nn.Module):
class HubertLayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -155,7 +156,7 @@ class HubertLayerNormConvLayer(nn.Module):
return hidden_states
class HubertGroupNormConvLayer(nn.Module):
class HubertGroupNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -212,13 +213,7 @@ class HubertFeatureEncoder(nn.Module):
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
hidden_states = conv_layer(hidden_states)
return hidden_states
@ -417,7 +412,7 @@ class HubertFeedForward(nn.Module):
return hidden_states
class HubertEncoderLayer(nn.Module):
class HubertEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = HubertAttention(
@ -501,17 +496,9 @@ class HubertEncoder(nn.Module):
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if skip_the_layer:
@ -579,7 +566,7 @@ class HubertAttnAdapterLayer(nn.Module):
return hidden_states
class HubertEncoderLayerStableLayerNorm(nn.Module):
class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = HubertAttention(
@ -675,17 +662,9 @@ class HubertEncoderStableLayerNorm(nn.Module):
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if skip_the_layer:

View File
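Hubert's feature encoder keeps its `hidden_states.requires_grad = True` line even though the conv layers now checkpoint themselves. A plausible reading, based only on the surrounding code: the reentrant checkpoint produces no gradients for a block whose inputs do not require grad, and raw audio features never do, so the input is flagged before entering the conv stack. A torch-level illustration:

import torch
from torch.utils.checkpoint import checkpoint

conv = torch.nn.Conv1d(1, 4, kernel_size=3)
wave = torch.randn(1, 1, 16)   # raw features: requires_grad is False by default
wave.requires_grad_(True)      # the feature-encoder trick

out = checkpoint(conv, wave, use_reentrant=True)
out.sum().backward()
assert conv.weight.grad is not None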

@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel
from ...processing_utils import Unpack
@ -668,7 +669,7 @@ class IdeficsAttention(nn.Module):
# this was adapted from LlamaDecoderLayer
class IdeficsDecoderLayer(nn.Module):
class IdeficsDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
@ -749,7 +750,7 @@ class IdeficsDecoderLayer(nn.Module):
return outputs
class IdeficsGatedCrossAttentionLayer(nn.Module):
class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer):
def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
@ -1185,95 +1186,32 @@ class IdeficsModel(IdeficsPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
def vblock(
main_block,
hidden_states,
attention_mask,
position_ids,
past_key_value,
image_hidden_states,
image_attention_mask,
cross_attention_gate,
output_attentions,
use_cache,
layer_idx,
cross_layer_interval,
gated_cross_attn_layers,
cache_position,
):
# TODO(ls): Add cross attention values to respective lists
if layer_idx % cross_layer_interval == 0:
xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
outputs = xblock(
hidden_states,
attention_mask=attention_mask,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
cross_attention_gate=cross_attention_gate,
output_attentions=output_attentions,
use_cache=use_cache,
past_key_value=None, # not implemented
**kwargs,
)
hidden_states = outputs[0]
layer_outputs = main_block(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
return layer_outputs
if self.gradient_checkpointing and self.training:
past_key_values = None
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
vblock,
decoder_layer,
# TODO(ls): Add cross attention values to respective lists
if idx % self.cross_layer_interval == 0:
cross_attn_block = self.gated_cross_attn_layers[idx // self.cross_layer_interval]
outputs = cross_attn_block(
hidden_states,
attention_mask,
position_ids,
past_key_values,
image_hidden_states,
image_attention_mask,
cross_attention_gate,
output_attentions,
use_cache,
idx,
self.cross_layer_interval,
self.gated_cross_attn_layers,
cache_position,
)
else:
layer_outputs = vblock(
decoder_layer,
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
cross_attention_gate=cross_attention_gate,
output_attentions=output_attentions,
use_cache=use_cache,
layer_idx=idx,
cross_layer_interval=self.cross_layer_interval,
gated_cross_attn_layers=self.gated_cross_attn_layers,
cache_position=cache_position,
past_key_value=None, # not implemented
**kwargs,
)
hidden_states = outputs[0]
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -23,6 +23,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import (
@ -283,7 +284,7 @@ class IdeficsVisionMLP(nn.Module):
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision
class IdeficsVisionEncoderLayer(nn.Module):
class IdeficsVisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: IdeficsVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -400,21 +401,12 @@ class IdeficsVisionEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -339,7 +340,7 @@ class Idefics2MultiheadAttentionPoolingHead(nn.Module):
return hidden_state[:, 0]
class Idefics2EncoderLayer(nn.Module):
class Idefics2EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Idefics2VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -448,19 +449,11 @@ class Idefics2Encoder(nn.Module):
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -26,6 +26,7 @@ from ...cache_utils import DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -300,7 +301,7 @@ class Idefics3SimpleMLP(nn.Module):
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
class Idefics3EncoderLayer(nn.Module):
class Idefics3EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Idefics3VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -409,19 +410,11 @@ class Idefics3Encoder(nn.Module):
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -12,6 +12,7 @@ import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -357,7 +358,7 @@ class IJepaOutput(nn.Module):
return hidden_states
class IJepaLayer(nn.Module):
class IJepaLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: IJepaConfig) -> None:
@ -423,15 +424,7 @@ class IJepaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]

View File
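The vision encoders in this stretch (IJepa above, Hiera, InternVL, the Idefics vision towers) all collapse to a single plain `layer_module(...)` call. A rough way to sanity-check that checkpointing still engages end to end is to compare peak memory; the sketch below assumes a CUDA device and uses a generic ViT checkpoint as a stand-in, not a model touched in this hunk:

import torch
from transformers import AutoModel


def peak_mib(model, pixel_values):
    model.zero_grad(set_to_none=True)
    torch.cuda.reset_peak_memory_stats()
    model(pixel_values=pixel_values).last_hidden_state.sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20


if torch.cuda.is_available():
    model = AutoModel.from_pretrained("google/vit-base-patch16-224").cuda().train()
    pixels = torch.randn(8, 3, 224, 224, device="cuda")
    baseline = peak_mib(model, pixels)
    model.gradient_checkpointing_enable()
    checkpointed = peak_mib(model, pixels)
    print(f"no checkpointing: {baseline:.0f} MiB, checkpointing: {checkpointed:.0f} MiB")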

@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -401,7 +402,7 @@ class ImageGPTMLP(nn.Module):
return hidden_states
class ImageGPTBlock(nn.Module):
class ImageGPTBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
@ -719,29 +720,16 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
outputs = block(
hidden_states,
layer_past,
attention_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:

View File

@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -744,7 +745,7 @@ class InformerProbSparseAttention(nn.Module):
# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py
class InformerConvLayer(nn.Module):
class InformerConvLayer(GradientCheckpointingLayer):
def __init__(self, c_in):
super().__init__()
self.downConv = nn.Conv1d(
@ -767,7 +768,7 @@ class InformerConvLayer(nn.Module):
return x
class InformerEncoderLayer(nn.Module):
class InformerEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: InformerConfig):
super().__init__()
self.embed_dim = config.d_model
@ -845,7 +846,7 @@ class InformerEncoderLayer(nn.Module):
return outputs
class InformerDecoderLayer(nn.Module):
class InformerDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -1086,27 +1087,15 @@ class InformerEncoder(InformerPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
if conv_layer is not None:
output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
if conv_layer is not None:
output = conv_layer(layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
if conv_layer is not None:
output = conv_layer(layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
hidden_states = layer_outputs[0]
@ -1299,35 +1288,18 @@ class InformerDecoder(InformerPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
)
@ -433,7 +434,7 @@ class InformerProbSparseAttention(nn.Module):
# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py
class InformerConvLayer(nn.Module):
class InformerConvLayer(GradientCheckpointingLayer):
def __init__(self, c_in):
super().__init__()
self.downConv = nn.Conv1d(
@ -610,27 +611,15 @@ class InformerEncoder(TimeSeriesTransformerEncoder):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
if conv_layer is not None:
output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
if conv_layer is not None:
output = conv_layer(layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
if conv_layer is not None:
output = conv_layer(layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
hidden_states = layer_outputs[0]

View File

@ -427,7 +427,7 @@ class InstructBlipEncoder(nn.Module):
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
@ -889,11 +889,11 @@ class InstructBlipQFormerEncoder(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
query_length=query_length,
)
hidden_states = layer_outputs[0]

View File

@ -356,7 +356,7 @@ class InstructBlipVideoEncoder(nn.Module):
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
@ -750,11 +750,11 @@ class InstructBlipVideoQFormerEncoder(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
query_length=query_length,
)
hidden_states = layer_outputs[0]

View File

@ -31,6 +31,7 @@ from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -383,7 +384,7 @@ class InternVLVisionMLP(nn.Module):
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
class InternVLVisionLayer(nn.Module):
class InternVLVisionLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: InternVLVisionConfig) -> None:
@ -452,12 +453,7 @@ class InternVLVisionEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, output_attentions
)
else:
layer_outputs = layer_module(hidden_states, output_attentions)
layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0]

View File

@ -24,6 +24,7 @@ import torch.utils.checkpoint
from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -334,7 +335,7 @@ class InternVLVisionMLP(CLIPMLP):
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
class InternVLVisionLayer(nn.Module):
class InternVLVisionLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: InternVLVisionConfig) -> None:
@ -403,12 +404,7 @@ class InternVLVisionEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, output_attentions
)
else:
layer_outputs = layer_module(hidden_states, output_attentions)
layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0]

View File

@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
@@ -763,7 +764,7 @@ JETMOE_ATTENTION_CLASSES = {
}
class JetMoeBlock(nn.Module):
class JetMoeBlock(GradientCheckpointingLayer):
def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
"""
Initialize the JetMoeBlock module.
@@ -967,28 +968,15 @@ class JetMoeModel(JetMoePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_ids,
past_key_values,
causal_mask,
output_attentions,
output_router_logits,
use_cache,
use_reentrant=False,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]

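One detail worth noting for decoder models such as JetMoe: key/value caching and activation checkpointing work against each other, because anything cached inside a checkpointed segment is recomputed during the backward pass anyway. Transformers decoders therefore typically switch caching off while checkpointing is active; a simplified sketch of that guard (wording is illustrative, not copied from a specific model file):

if self.gradient_checkpointing and self.training and use_cache:
    # Caching while recomputing activations is wasted work and memory,
    # so it is disabled for the duration of this forward pass.
    use_cache = False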
View File

@@ -25,6 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -404,7 +405,7 @@ class Kosmos2VisionMLP(nn.Module):
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision
class Kosmos2VisionEncoderLayer(nn.Module):
class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
@@ -521,21 +522,12 @@ class Kosmos2VisionEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@@ -840,7 +832,7 @@ class Kosmos2TextFFN(nn.Module):
return hidden_states
class Kosmos2TextBlock(nn.Module):
class Kosmos2TextBlock(GradientCheckpointingLayer):
def __init__(self, config: Kosmos2TextConfig):
super().__init__()
self.embed_dim = config.embed_dim
@@ -1138,34 +1130,18 @@ class Kosmos2TextTransformer(nn.Module):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -358,7 +359,7 @@ class LayoutLMOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
class LayoutLMLayer(nn.Module):
class LayoutLMLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -484,27 +485,15 @@ class LayoutLMEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

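The `# as a positional argument for gradient checkpointing` comments mark the one tensor that must not be moved behind a keyword. The reentrant variant of torch.utils.checkpoint rejects extra keyword arguments, so a wrapper has to bind keywords outside the checkpoint, and only positional arguments are treated as checkpoint inputs whose gradients are tracked through the recomputation. A small stand-alone illustration with plain PyTorch (the `block` function is hypothetical; the exact error message may differ between PyTorch versions):

import torch
from torch.utils.checkpoint import checkpoint


def block(hidden_states, encoder_hidden_states=None):
    # Hypothetical stand-in for a cross-attention layer.
    return hidden_states if encoder_hidden_states is None else hidden_states * encoder_hidden_states


x = torch.randn(2, 4, requires_grad=True)
enc = torch.randn(2, 4, requires_grad=True)

out = checkpoint(block, x, enc, use_reentrant=True)  # both tensors are checkpoint inputs
out.sum().backward()                                 # gradients reach x and enc

try:
    checkpoint(block, x, encoder_hidden_states=enc, use_reentrant=True)
except ValueError:
    # The reentrant implementation does not accept keyword arguments at all.
    pass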
View File

@@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@@ -261,7 +262,7 @@ class LayoutLMv2Output(nn.Module):
return hidden_states
class LayoutLMv2Layer(nn.Module):
class LayoutLMv2Layer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -436,25 +437,14 @@ class LayoutLMv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
hidden_states = layer_outputs[0]
if output_attentions:

View File

@@ -25,6 +25,7 @@ import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
QuestionAnsweringModelOutput,
@@ -358,7 +359,7 @@ class LayoutLMv3Attention(nn.Module):
# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
class LayoutLMv3Layer(nn.Module):
class LayoutLMv3Layer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -514,25 +515,14 @@ class LayoutLMv3Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
rel_pos,
rel_2d_pos,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
hidden_states = layer_outputs[0]
if output_attentions:

View File

@@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
@@ -900,7 +901,7 @@ class LEDDecoderAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
class LEDEncoderLayer(nn.Module):
class LEDEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: LEDConfig, layer_id: int):
super().__init__()
self.embed_dim = config.d_model
@@ -962,7 +963,7 @@ class LEDEncoderLayer(nn.Module):
return (hidden_states,) + attn_outputs[1:]
class LEDDecoderLayer(nn.Module):
class LEDDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: LEDConfig):
super().__init__()
self.embed_dim = config.d_model
@@ -1680,27 +1681,15 @@ class LEDEncoder(LEDPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
is_index_masked,
is_index_global_attn,
is_global_attn,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
@@ -1943,33 +1932,17 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
combined_attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]

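None of these call-site changes alter how checkpointing is switched on; the user-facing entry point is still the model-level toggle. A minimal usage sketch (the checkpoint name is a placeholder):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("<model-name>")  # placeholder checkpoint
model.gradient_checkpointing_enable()  # enables recomputation on the supporting layers
model.train()
# A regular training step now trades extra forward compute for lower activation memory.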
Some files were not shown because too many files have changed in this diff.