mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Apply GradientCheckpointingLayer to the whole repo (#38913)
* first batch (4) * align * altclip * beit * bert * yolos * dino, pvt_v2 * bark, bart, bert_generation * big_bird, biogpt * blnderbot, bloom * bridgetower * camambert, canine, chameleon * chinese clip, clap, clip * codegen, conditional detr, convbert * dab_detr, data2vec * dbrx, deberta * deberta, decicion_tranformer, deformable_detr * deit, deta, mctct * detr, dinov2, distilbert * donut, dpt, electra * ernie, esm, falcon * flava, fnet, falcon_mamba * focalnet, git, gpt2 * gpt - bigcode, neo, neox * gptj, groupvit * idefics2, idefics3 * ijepa, imagegpt, internvl * jetmoe, kosmos2, layoutlm * layoutlm2-3, led * lilt, longformer, longt5, luke * m2m, mamba1-2 * marian, markuplm, mask2former * maskformer * mbart, megatron_bert, mimi * mixtral, mlcd * mobilevit1-2, modernbert * moshi, mpt, mra * mt5, musicgen * mvp, nemotron * nllb_moe * nystromformer, omdet_turbo * opt, owlvit, owlv2 * pegasus, pegasus_x, presimmon * phimoe, pix2struct, pixtral * plbart, pop2piano, prophetnet * qwen2* * qwen2, qwen3 moe, rec gemma * rembert * roberta * roberta prelayernorm * roc_bert, roformer, rwkv * sam, sam_hq * seggpt, smolvlm, speech_to_text * splinter, stablelm, swin * swin2sr, switch_transformer, t5, table_transformer * tapas, time_series_tranformer, timesformer * trocr, tvp, umt5 * videomae, vilt, visual_bert * vit, vit_mae, vit_msn * vitpose_backbone, vits, vivit * whisper. x_clip, xglm * xlm_roberta, xmod * yoso * zamba * vitdet, wav2vec2, wav2vec2_bert * unispeech, wav2vec_conformer * wavlm * speecht5 * swinv2 * sew / _d * seamless_mt4 / _v2 * deprecated models update * bros * gemma2, gemma3 * got, hiera, hubert, llama4, mllama, oneformer, phi, olmoe, informer * fixup * Add use_cache=False and past_key_value=None to GradientCheckpointingLayer * fixup * fix prophetnet * fix bigbird_pegasus * fix blenderbot * fix mbart * fix mvp * fix zamba2 * fix bart * fix blenderbot_small * fix codegen * Update gradient checkpointing layer to support more past_key_values arg names * fix data2vec vision * fix deformable_detr * fix gptj * fix led * fix m2m_100 * add comment * fix nnlb_moe * Fix pegasus_x * fix plbart * udop * fix-copies: beit, wav2vec2 * fix gpt_bigcode * fixup * fix t5 * fix switch_transformers * fix longt5 * fix mt5 * update tapas * fix blip2 * update blip * fix musicgen * fix gpt2, trocr * fix copies * !!! Revert zamba, mllama * update autoformer * update bros * update args / kwargs for BERT and copies * 2nd round of updates * update conditional detr * Pass encoder_hidden_states as positional arg * Update to pass encoder_decoder_position_bias as positional arg * fixup * biogpt modular * modular gemma2 * modular gemma3 * modular gpt_neox * modular informer * modular internvl * modular mixtral * modular mlcd * modular modernbert * modular phi * modular qwen2_5_omni * modular qwen2_5_vl * modular sam_hq * modular sew * wav2vec2_bert * modular wav2vec2_conformer * modular wavlm * fixup * Update by modular instructblipvideo * modular data2vec_audio * nit modular mistral * apply modular minimax * fix modular moonshine * revert zamba2 * fix mask2former * refactor idefics
This commit is contained in:
parent
07aab1af1e
commit
84d19be41e
@ -16,6 +16,11 @@ from functools import partial
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GradientCheckpointingLayer(nn.Module):
|
||||
"""Base class for layers with gradient checkpointing.
|
||||
@ -44,5 +49,35 @@ class GradientCheckpointingLayer(nn.Module):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
do_warn = False
|
||||
layer_name = self.__class__.__name__
|
||||
message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"
|
||||
|
||||
if "use_cache" in kwargs and kwargs["use_cache"]:
|
||||
kwargs["use_cache"] = False
|
||||
message += " `use_cache=False`,"
|
||||
do_warn = True
|
||||
|
||||
# different names for the same thing in different layers
|
||||
if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
|
||||
kwargs["past_key_value"] = None
|
||||
message += " `past_key_value=None`,"
|
||||
do_warn = True
|
||||
|
||||
if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
|
||||
kwargs["past_key_values"] = None
|
||||
message += " `past_key_values=None`,"
|
||||
do_warn = True
|
||||
|
||||
if "layer_past" in kwargs and kwargs["layer_past"] is not None:
|
||||
kwargs["layer_past"] = None
|
||||
message += " `layer_past=None`,"
|
||||
do_warn = True
|
||||
|
||||
# warn if anything was changed
|
||||
if do_warn:
|
||||
message = message.rstrip(",") + "."
|
||||
logger.warning(message)
|
||||
|
||||
return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
@ -23,6 +23,7 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithNoAttention,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -827,7 +828,7 @@ class AlignTextOutput(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
|
||||
class AlignTextLayer(nn.Module):
|
||||
class AlignTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -953,27 +954,15 @@ class AlignTextEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -23,6 +23,7 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -418,7 +419,7 @@ class AltRobertaOutput(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
|
||||
class AltRobertaLayer(nn.Module):
|
||||
class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -544,27 +545,15 @@ class AltRobertaEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
@ -732,7 +721,7 @@ class AltCLIPMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AltCLIPEncoderLayer(nn.Module):
|
||||
class AltCLIPEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: AltCLIPConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -848,21 +837,12 @@ class AltCLIPEncoder(nn.Module):
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -22,6 +22,7 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
@ -282,7 +283,7 @@ class ASTOutput(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
|
||||
class ASTLayer(nn.Module):
|
||||
class ASTLayer(GradientCheckpointingLayer):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config: ASTConfig) -> None:
|
||||
@ -349,16 +350,7 @@ class ASTEncoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
||||
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
|
@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
|
||||
@ -670,7 +671,7 @@ class AutoformerAttention(nn.Module):
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
class AutoformerEncoderLayer(nn.Module):
|
||||
class AutoformerEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: AutoformerConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -744,7 +745,7 @@ class AutoformerEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class AutoformerDecoderLayer(nn.Module):
|
||||
class AutoformerDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: AutoformerConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -1042,21 +1043,12 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
|
||||
if to_drop:
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -1186,6 +1178,12 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
# expand encoder attention mask
|
||||
@ -1228,38 +1226,17 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
(hidden_states, residual_trend) = layer_outputs[0]
|
||||
trend = trend + residual_trend
|
||||
|
||||
|
@ -31,6 +31,7 @@ from ...generation.logits_process import (
|
||||
)
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
|
||||
from ...modeling_utils import PreTrainedModel, get_parameter_device
|
||||
from ...utils import (
|
||||
@ -309,7 +310,7 @@ class BarkMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BarkBlock(nn.Module):
|
||||
class BarkBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config, is_causal=False):
|
||||
super().__init__()
|
||||
|
||||
@ -606,25 +607,14 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
past_key_values=past_layer_key_values,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
past_key_values=past_layer_key_values,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
|
@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -270,7 +271,7 @@ class BartAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class BartEncoderLayer(nn.Module):
|
||||
class BartEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -341,7 +342,7 @@ class BartEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class BartDecoderLayer(nn.Module):
|
||||
class BartDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -875,21 +876,12 @@ class BartEncoder(BartPreTrainedModel):
|
||||
if to_drop:
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -1137,35 +1129,18 @@ class BartDecoder(BartPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
|
@ -26,6 +26,7 @@ from torch import Tensor, nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BackboneOutput,
|
||||
BaseModelOutput,
|
||||
@ -497,7 +498,7 @@ class BeitOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BeitLayer(nn.Module):
|
||||
class BeitLayer(GradientCheckpointingLayer):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
|
||||
@ -525,7 +526,7 @@ class BeitLayer(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional[torch.Tensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[tuple[int]] = None,
|
||||
resolution: Optional[tuple[int, int]] = None,
|
||||
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
|
||||
@ -695,25 +696,14 @@ class BeitEncoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
relative_position_bias,
|
||||
interpolate_pos_encoding,
|
||||
resolution,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
relative_position_bias,
|
||||
interpolate_pos_encoding,
|
||||
resolution,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
relative_position_bias=relative_position_bias,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
resolution=resolution,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
@ -522,7 +523,7 @@ class BertOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
|
||||
class BertLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -647,27 +648,15 @@ class BertEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -23,6 +23,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
@ -275,7 +276,7 @@ class BertGenerationOutput(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration
|
||||
class BertGenerationLayer(nn.Module):
|
||||
class BertGenerationLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -401,27 +402,15 @@ class BertEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
@ -1419,7 +1420,7 @@ class BigBirdOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BigBirdLayer(nn.Module):
|
||||
class BigBirdLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, seed=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -1593,35 +1594,19 @@ class BigBirdEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
band_mask,
|
||||
from_mask,
|
||||
to_mask,
|
||||
blocked_encoder_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
band_mask,
|
||||
from_mask,
|
||||
to_mask,
|
||||
blocked_encoder_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
band_mask,
|
||||
from_mask,
|
||||
to_mask,
|
||||
blocked_encoder_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -1333,7 +1334,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class BigBirdPegasusEncoderLayer(nn.Module):
|
||||
class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BigBirdPegasusConfig, seed=None):
|
||||
super().__init__()
|
||||
self.attention_type = config.attention_type
|
||||
@ -1420,7 +1421,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
|
||||
self.self_attn.set_attention_type(value)
|
||||
|
||||
|
||||
class BigBirdPegasusDecoderLayer(nn.Module):
|
||||
class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -1947,31 +1948,17 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
||||
if to_drop:
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
band_mask,
|
||||
from_mask,
|
||||
to_mask,
|
||||
blocked_encoder_mask,
|
||||
blocked_encoder_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
band_mask=band_mask,
|
||||
from_mask=from_mask,
|
||||
to_mask=to_mask,
|
||||
from_blocked_mask=blocked_encoder_mask,
|
||||
to_blocked_mask=blocked_encoder_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
band_mask=band_mask,
|
||||
from_mask=from_mask,
|
||||
to_mask=to_mask,
|
||||
from_blocked_mask=blocked_encoder_mask,
|
||||
to_blocked_mask=blocked_encoder_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -2297,35 +2284,18 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
|
@ -20,7 +20,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -32,6 +31,7 @@ from ...cache_utils import Cache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
@ -248,7 +248,7 @@ class BioGptAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class BioGptDecoderLayer(nn.Module):
|
||||
class BioGptDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -646,30 +646,17 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
position_ids,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -15,7 +15,6 @@
|
||||
"""PyTorch BioGPT model."""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@ -473,30 +472,17 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
position_ids,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -270,7 +271,7 @@ class BlenderbotAttention(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
|
||||
class BlenderbotEncoderLayer(nn.Module):
|
||||
class BlenderbotEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -339,7 +340,7 @@ class BlenderbotEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
|
||||
class BlenderbotDecoderLayer(nn.Module):
|
||||
class BlenderbotDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -825,21 +826,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
if to_drop:
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -1090,35 +1082,18 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
|
@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -254,7 +255,7 @@ class BlenderbotSmallAttention(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
|
||||
class BlenderbotSmallEncoderLayer(nn.Module):
|
||||
class BlenderbotSmallEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -326,7 +327,7 @@ class BlenderbotSmallEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
|
||||
class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -812,21 +813,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
if to_drop:
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -1073,35 +1065,18 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
|
@ -552,7 +552,7 @@ class BlipEncoder(nn.Module):
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -531,7 +531,7 @@ class Blip2Encoder(nn.Module):
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
@ -992,11 +992,11 @@ class Blip2QFormerEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
query_length,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
query_length=query_length,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -27,6 +27,7 @@ from torch.nn import functional as F
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
@ -366,7 +367,7 @@ class BloomMLP(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class BloomBlock(nn.Module):
|
||||
class BloomBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -605,29 +606,16 @@ class BloomModel(BloomPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
past_key_values,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=past_key_values,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
alibi=alibi,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=past_key_values,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
alibi=alibi,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache:
|
||||
|
@ -25,6 +25,7 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN, QuickGELUActivation
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
@ -662,7 +663,7 @@ class BridgeTowerBertCrossLayer(nn.Module):
|
||||
return layer_output
|
||||
|
||||
|
||||
class BridgeTowerTextLayer(nn.Module):
|
||||
class BridgeTowerTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -788,27 +789,15 @@ class BridgeTowerTextEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -24,6 +24,7 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
@ -428,7 +429,7 @@ class BrosOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BrosLayer(nn.Module):
|
||||
class BrosLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -550,34 +551,16 @@ class BrosEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
||||
"`use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
bbox_pos_emb,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
bbox_pos_emb=bbox_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
bbox_pos_emb,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN, gelu
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
@ -478,7 +479,7 @@ class CamembertOutput(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert
|
||||
class CamembertLayer(nn.Module):
|
||||
class CamembertLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -604,27 +605,15 @@ class CamembertEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
ModelOutput,
@ -672,7 +673,7 @@ class CanineOutput(nn.Module):
return hidden_states

class CanineLayer(nn.Module):
class CanineLayer(GradientCheckpointingLayer):
def __init__(
self,
config,
@ -779,16 +780,7 @@ class CanineEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
if output_attentions:

@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -383,7 +384,7 @@ class ChameleonAttention(nn.Module):

# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
class ChameleonDecoderLayer(nn.Module):
class ChameleonDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -458,7 +459,7 @@ class ChameleonDecoderLayer(nn.Module):
return outputs

class ChameleonSwinDecoderLayer(nn.Module):
class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -1011,28 +1012,16 @@ class ChameleonModel(ChameleonPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

hidden_states = layer_outputs[0]

@ -23,6 +23,7 @@ import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -577,7 +578,7 @@ class ChineseCLIPVisionMLP(nn.Module):

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
class ChineseCLIPTextLayer(nn.Module):
class ChineseCLIPTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -663,7 +664,7 @@ class ChineseCLIPTextLayer(nn.Module):
return layer_output

class ChineseCLIPVisionLayer(nn.Module):
class ChineseCLIPVisionLayer(GradientCheckpointingLayer):
def __init__(self, config: ChineseCLIPConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -816,27 +817,15 @@ class ChineseCLIPTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:
@ -920,17 +909,10 @@ class ChineseCLIPVisionEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

@ -24,6 +24,7 @@ import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
@ -691,7 +692,7 @@ class ClapAudioLayer(nn.Module):

# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio
class ClapAudioStage(nn.Module):
class ClapAudioStage(GradientCheckpointingLayer):
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
@ -928,14 +929,9 @@ class ClapAudioEncoder(nn.Module):

input_dimensions = self.input_resolutions[i]

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)

hidden_states = layer_outputs[0]

@ -1355,7 +1351,7 @@ class ClapTextOutput(nn.Module):

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText
class ClapTextLayer(nn.Module):
class ClapTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -1481,27 +1477,15 @@ class ClapTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:

@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
@ -393,7 +394,7 @@ class CLIPMLP(nn.Module):
return hidden_states

class CLIPEncoderLayer(nn.Module):
class CLIPEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
super().__init__()
self.embed_dim = config.hidden_size
@ -575,21 +576,12 @@ class CLIPEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

@ -25,6 +25,7 @@ from torch import nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int
@ -374,7 +375,7 @@ class CLIPSegMLP(nn.Module):

# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
class CLIPSegEncoderLayer(nn.Module):
class CLIPSegEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: CLIPSegConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -539,22 +540,12 @@ class CLIPSegEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)

layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

if output_attentions:

@ -24,6 +24,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
@ -245,7 +246,7 @@ class CodeGenMLP(nn.Module):

# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(nn.Module):
class CodeGenBlock(GradientCheckpointingLayer):
# Ignore copy
def __init__(self, config, layer_idx=None):
super().__init__()
@ -437,29 +438,16 @@ class CodeGenModel(CodeGenPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
causal_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)

hidden_states = outputs[0]
if use_cache is True:

@ -23,6 +23,7 @@ from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends
@ -827,7 +828,7 @@ class ConditionalDetrEncoderLayer(nn.Module):
return outputs

class ConditionalDetrDecoderLayer(nn.Module):
class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ConditionalDetrConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1297,31 +1298,18 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
pos_transformation = self.query_scale(hidden_states)
# apply transformation
query_sine_embed = query_sine_embed_before_transformation * pos_transformation
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
None,
object_queries,
query_position_embeddings,
query_sine_embed,
encoder_hidden_states,
encoder_attention_mask,
None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=None,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
query_sine_embed=query_sine_embed,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
is_first=(idx == 0),
)

layer_outputs = decoder_layer(
hidden_states,
None, # attention_mask
object_queries,
query_position_embeddings,
query_sine_embed,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
is_first=(idx == 0),
)

hidden_states = layer_outputs[0]

@ -25,6 +25,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, get_activation
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
MaskedLMOutput,
@ -532,7 +533,7 @@ class ConvBertOutput(nn.Module):
return hidden_states

class ConvBertLayer(nn.Module):
class ConvBertLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -620,25 +621,14 @@ class ConvBertEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)

@ -23,6 +23,7 @@ from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@ -702,7 +703,7 @@ class DabDetrDecoderLayerFFN(nn.Module):

# Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig
class DabDetrEncoderLayer(nn.Module):
class DabDetrEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DabDetrConfig):
super().__init__()
self.hidden_size = config.hidden_size
@ -764,7 +765,7 @@ class DabDetrEncoderLayer(nn.Module):

# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr
class DabDetrDecoderLayer(nn.Module):
class DabDetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DabDetrConfig, is_first: bool = False):
super().__init__()
self.self_attn = DabDetrDecoderLayerSelfAttention(config)
@ -976,21 +977,12 @@ class DabDetrEncoder(DabDetrPreTrainedModel):
# we add object_queries * pos_scaler as extra input to the encoder_layer
scaled_object_queries = object_queries * pos_scales

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
scaled_object_queries,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
object_queries=scaled_object_queries,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
object_queries=scaled_object_queries,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

@ -1138,29 +1130,16 @@ class DabDetrDecoder(DabDetrPreTrainedModel):
reference_anchor_size[..., 1] / obj_center[..., 3]
).unsqueeze(-1)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
None,
object_queries,
query_pos,
query_sine_embed,
encoder_hidden_states,
memory_key_padding_mask,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=None,
object_queries=object_queries,
query_position_embeddings=query_pos,
query_sine_embed=query_sine_embed,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=memory_key_padding_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
None, # attention_mask
object_queries,
query_pos,
query_sine_embed,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=memory_key_padding_mask,
output_attentions=output_attentions,
)

# iter update
hidden_states = layer_outputs[0]

@ -33,6 +33,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@ -51,7 +52,7 @@ if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask

class Data2VecAudioConvLayer(nn.Module):
class Data2VecAudioConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -155,13 +156,7 @@ class Data2VecAudioFeatureEncoder(nn.Module):
hidden_states.requires_grad = True

for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
hidden_states = conv_layer(hidden_states)

return hidden_states

@ -357,7 +352,7 @@ class Data2VecAudioFeedForward(nn.Module):
return hidden_states

class Data2VecAudioEncoderLayer(nn.Module):
class Data2VecAudioEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = Data2VecAudioAttention(
@ -441,17 +436,9 @@ class Data2VecAudioEncoder(nn.Module):
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]

if skip_the_layer:

@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -375,7 +376,7 @@ class Data2VecTextOutput(nn.Module):

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
class Data2VecTextLayer(nn.Module):
class Data2VecTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -501,27 +502,15 @@ class Data2VecTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:

@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@ -497,7 +498,7 @@ class Data2VecVisionOutput(nn.Module):

# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
class Data2VecVisionLayer(nn.Module):
class Data2VecVisionLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""

def __init__(
@ -527,7 +528,7 @@ class Data2VecVisionLayer(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[tuple[int]] = None,
resolution: Optional[tuple[int, int]] = None,
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
@ -699,25 +700,14 @@ class Data2VecVisionEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
else:
layer_outputs = layer_module(
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
layer_outputs = layer_module(
hidden_states,
head_mask=layer_head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)

hidden_states = layer_outputs[0]

@ -20,6 +20,7 @@ import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ..wav2vec2.modeling_wav2vec2 import (
@ -38,7 +39,7 @@ from ..wav2vec2.modeling_wav2vec2 import (
from .configuration_data2vec_audio import Data2VecAudioConfig

class Data2VecAudioConvLayer(nn.Module):
class Data2VecAudioConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1

@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
@ -724,7 +725,7 @@ class DbrxFFN(nn.Module):
return out, weights

class DbrxBlock(nn.Module):
class DbrxBlock(GradientCheckpointingLayer):
def __init__(self, config: DbrxConfig, block_idx: int):
super().__init__()
self.hidden_size = config.d_model
@ -947,29 +948,16 @@ class DbrxModel(DbrxPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
block_outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cache_position,
)
else:
block_outputs = block(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
)
block_outputs = block(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = block_outputs[0]

@ -22,6 +22,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@ -492,7 +493,7 @@ class DebertaOutput(nn.Module):
return hidden_states

class DebertaLayer(nn.Module):
class DebertaLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = DebertaAttention(config)
@ -580,25 +581,14 @@ class DebertaEncoder(nn.Module):

rel_embeddings = self.get_rel_embedding()
for i, layer_module in enumerate(self.layer):
if self.gradient_checkpointing and self.training:
hidden_states, att_m = self._gradient_checkpointing_func(
layer_module.__call__,
next_kv,
attention_mask,
query_states,
relative_pos,
rel_embeddings,
output_attentions,
)
else:
hidden_states, att_m = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
hidden_states, att_m = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@ -418,7 +419,7 @@ class DebertaV2Output(nn.Module):

# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
class DebertaV2Layer(nn.Module):
class DebertaV2Layer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.attention = DebertaV2Attention(config)
@ -655,25 +656,14 @@ class DebertaV2Encoder(nn.Module):
next_kv = hidden_states
rel_embeddings = self.get_rel_embedding()
for i, layer_module in enumerate(self.layer):
if self.gradient_checkpointing and self.training:
output_states, attn_weights = self._gradient_checkpointing_func(
layer_module.__call__,
next_kv,
attention_mask,
query_states,
relative_pos,
rel_embeddings,
output_attentions,
)
else:
output_states, attn_weights = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
output_states, attn_weights = layer_module(
next_kv,
attention_mask,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)

if output_attentions:
all_attentions = all_attentions + (attn_weights,)

@ -25,6 +25,7 @@ from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
@ -360,7 +361,7 @@ class DecisionTransformerGPT2MLP(nn.Module):

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
class DecisionTransformerGPT2Block(GradientCheckpointingLayer):
# Ignore copy
def __init__(self, config, layer_idx=None):
super().__init__()
@ -654,31 +655,17 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
past_key_value=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
outputs = block(
hidden_states,
past_key_values if not (self.gradient_checkpointing and self.training) else None,
cache_position,
attention_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)

hidden_states = outputs[0]

@ -27,6 +27,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
@ -759,7 +760,7 @@ class DeformableDetrMultiheadAttention(nn.Module):
return attn_output, attn_weights_reshaped

class DeformableDetrEncoderLayer(nn.Module):
class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DeformableDetrConfig):
super().__init__()
self.embed_dim = config.d_model
@ -848,7 +849,7 @@ class DeformableDetrEncoderLayer(nn.Module):
return outputs

class DeformableDetrDecoderLayer(nn.Module):
class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DeformableDetrConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1126,29 +1127,16 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
for i, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
position_embeddings,
reference_points,
spatial_shapes,
spatial_shapes_list,
level_start_index,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

@ -1273,31 +1261,17 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
reference_points_input,
spatial_shapes,
spatial_shapes_list,
level_start_index,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings,
reference_points_input,
spatial_shapes,
spatial_shapes_list,
level_start_index,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask,
output_attentions,
)

hidden_states = layer_outputs[0]

@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@ -347,7 +348,7 @@ class DeiTOutput(nn.Module):

# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
class DeiTLayer(nn.Module):
class DeiTLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""

def __init__(self, config: DeiTConfig) -> None:
@ -414,15 +415,7 @@ class DeiTEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]

@ -39,6 +39,7 @@ from ....file_utils import (
replace_return_docstrings,
)
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import meshgrid
@ -909,7 +910,7 @@ class DetaEncoderLayer(nn.Module):
return outputs

class DetaDecoderLayer(nn.Module):
class DetaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DetaConfig):
super().__init__()
self.embed_dim = config.d_model
@ -1341,29 +1342,16 @@ class DetaDecoder(DetaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
reference_points_input,
spatial_shapes,
level_start_index,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

@ -26,6 +26,7 @@ from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add
from ....integrations.deepspeed import is_deepspeed_zero3_enabled
from ....integrations.fsdp import is_fsdp_managed_module
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput, CausalLMOutput
from ....modeling_utils import (
PreTrainedModel,
@ -377,7 +378,7 @@ class MCTCTOutput(nn.Module):
return hidden_states

class MCTCTLayer(nn.Module):
class MCTCTLayer(GradientCheckpointingLayer):
def __init__(self, config: MCTCTConfig):
super().__init__()

@ -591,20 +592,11 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -438,7 +439,7 @@ class NezhaOutput(nn.Module):
return hidden_states

class NezhaLayer(nn.Module):
class NezhaLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -563,27 +564,15 @@ class NezhaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:

@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ....activations import ACT2FN
from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ....modeling_utils import PreTrainedModel
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@ -339,7 +340,7 @@ class OpenLlamaAttention(nn.Module):
return attn_output, attn_weights, past_key_value

class OpenLlamaDecoderLayer(nn.Module):
class OpenLlamaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: OpenLlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
@ -631,25 +632,14 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -452,7 +453,7 @@ class QDQBertOutput(nn.Module):

# Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert
class QDQBertLayer(nn.Module):
class QDQBertLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.seq_len_dim = 1
@ -568,32 +569,15 @@ class QDQBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:

@ -24,6 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss

from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -447,7 +448,7 @@ class RealmOutput(nn.Module):
return hidden_states

class RealmLayer(nn.Module):
class RealmLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -572,27 +573,15 @@ class RealmEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:

@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss

from ....activations import ACT2FN
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ....modeling_utils import PreTrainedModel
from ....utils import add_start_docstrings, logging, replace_return_docstrings
@ -263,7 +264,7 @@ class Speech2Text2Attention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value

class Speech2Text2DecoderLayer(nn.Module):
class Speech2Text2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Speech2Text2Config):
super().__init__()
self.embed_dim = config.d_model
@ -612,31 +613,17 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]

if use_cache:

@ -25,6 +25,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F

from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_utils import PreTrainedModel
from ....utils import (
ModelOutput,

@ -346,7 +347,7 @@ class CausalSelfAttention(nn.Module):
return outputs

class Block(nn.Module):
class Block(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)

@ -540,16 +541,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
layer_past,
use_cache,
output_attentions,
)
else:
outputs = block(hidden_states, layer_past, use_cache, output_attentions)
outputs = block(hidden_states, layer_past, use_cache, output_attentions)

hidden_states = outputs[0]
if use_cache is True:
@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

@ -483,7 +484,7 @@ class TvltOutput(nn.Module):
return hidden_states

class TvltLayer(nn.Module):
class TvltLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""

def __init__(self, config):

@ -546,16 +547,7 @@ class TvltEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]

@ -853,15 +845,7 @@ class TvltDecoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
None,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

hidden_states = layer_outputs[0]
@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

@ -390,7 +391,7 @@ VIT_HYBRID_ATTENTION_CLASSES = {
}

class ViTHybridLayer(nn.Module):
class ViTHybridLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""

def __init__(self, config: ViTHybridConfig) -> None:

@ -457,15 +458,7 @@ class ViTHybridEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
@ -26,6 +26,7 @@ from torch import Tensor, nn
from torch.nn import LayerNorm

from ....activations import ACT2FN
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput
from ....modeling_utils import PreTrainedModel
from ....utils import (

@ -1090,7 +1091,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
return predict_relative_pos_embeddings

class XLMProphetNetEncoderLayer(nn.Module):
class XLMProphetNetEncoderLayer(GradientCheckpointingLayer):
"""
Encoder block for XLMProphetnet
"""

@ -1133,7 +1134,7 @@ class XLMProphetNetEncoderLayer(nn.Module):
return outputs

class XLMProphetNetDecoderLayer(nn.Module):
class XLMProphetNetDecoderLayer(GradientCheckpointingLayer):
"""
Decoder block for XLMProphetnet
"""

@ -1320,21 +1321,12 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
if output_hidden_states:
encoder_hidden_states = encoder_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
extended_attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

@ -1554,41 +1546,21 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
extended_attention_mask,
encoder_hidden_states,
extended_encoder_attention_mask,
(head_mask[idx] if head_mask is not None else None),
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
extended_predict_attention_mask,
main_relative_position_buckets,
predict_relative_position_buckets,
position_ids,
None,
use_cache,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
@ -23,6 +23,7 @@ from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (

@ -677,7 +678,7 @@ class DetrEncoderLayer(nn.Module):
return outputs

class DetrDecoderLayer(nn.Module):
class DetrDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DetrConfig):
super().__init__()
self.embed_dim = config.d_model

@ -1045,25 +1046,15 @@ class DetrDecoder(DetrPreTrainedModel):
if dropout_probability < self.layerdrop:
continue

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = decoder_layer(
hidden_states,
combined_attention_mask,
object_queries,
query_position_embeddings,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

@ -382,7 +383,7 @@ class Dinov2SwiGLUFFN(nn.Module):
return self.weights_out(hidden)

class Dinov2Layer(nn.Module):
class Dinov2Layer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the original implementation."""

def __init__(self, config: Dinov2Config) -> None:

@ -458,15 +459,7 @@ class Dinov2Encoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
@ -28,6 +28,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

@ -399,7 +400,7 @@ class Dinov2WithRegistersSwiGLUFFN(nn.Module):
return self.weights_out(hidden)

class Dinov2WithRegistersLayer(nn.Module):
class Dinov2WithRegistersLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the original implementation."""

def __init__(self, config: Dinov2WithRegistersConfig) -> None:

@ -476,15 +477,7 @@ class Dinov2WithRegistersEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
@ -31,6 +31,7 @@ from ...configuration_utils import PretrainedConfig
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,

@ -441,7 +442,7 @@ DISTILBERT_ATTENTION_CLASSES = {
}

class TransformerBlock(nn.Module):
class TransformerBlock(GradientCheckpointingLayer):
def __init__(self, config: PretrainedConfig):
super().__init__()

@ -537,21 +538,12 @@ class Transformer(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_state,
attn_mask,
head_mask[i],
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_state,
attn_mask,
head_mask[i],
output_attentions,
)
layer_outputs = layer_module(
hidden_state,
attn_mask,
head_mask[i],
output_attentions,
)

hidden_state = layer_outputs[-1]
@ -27,6 +27,7 @@ import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int

@ -706,7 +707,7 @@ class DonutSwinLayer(nn.Module):

# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
class DonutSwinStage(GradientCheckpointingLayer):
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config

@ -816,19 +817,9 @@ class DonutSwinEncoder(nn.Module):
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
else:
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
)

hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
@ -29,6 +29,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

@ -469,7 +470,7 @@ class DPTViTOutput(nn.Module):

# copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput
class DPTViTLayer(nn.Module):
class DPTViTLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""

def __init__(self, config: DPTConfig) -> None:

@ -536,15 +537,7 @@ class DPTViTEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, get_activation
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPastAndCrossAttentions,

@ -436,7 +437,7 @@ class ElectraOutput(nn.Module):

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
class ElectraLayer(nn.Module):
class ElectraLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward

@ -562,27 +563,15 @@ class ElectraEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:
@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,

@ -361,7 +362,7 @@ class ErnieOutput(nn.Module):

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie
class ErnieLayer(nn.Module):
class ErnieLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward

@ -487,27 +488,15 @@ class ErnieEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:
@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,

@ -599,7 +600,7 @@ class EsmOutput(nn.Module):
return hidden_states

class EsmLayer(nn.Module):
class EsmLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward

@ -725,27 +726,15 @@ class EsmEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:
@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
)
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,

@ -556,7 +557,7 @@ FALCON_ATTENTION_CLASSES = {
}

class FalconDecoderLayer(nn.Module):
class FalconDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: FalconConfig, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size

@ -836,33 +837,18 @@ class FalconModel(FalconPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
position_ids,
head_mask[i],
past_key_values,
use_cache,
output_attentions,
cache_position,
position_embeddings,
)
else:
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

hidden_states = outputs[0]
if use_cache is True:
@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import MambaCache
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available

@ -405,7 +406,7 @@ class FalconMambaRMSNorm(nn.Module):

# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaBlock(nn.Module):
class FalconMambaBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config

@ -620,17 +621,12 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -25,6 +25,7 @@ import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int

@ -577,7 +578,7 @@ class FlavaOutput(nn.Module):
return hidden_states

class FlavaLayer(nn.Module):
class FlavaLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""

def __init__(self, config: FlavaPossibleConfigs) -> None:

@ -648,16 +649,7 @@ class FlavaEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
@ -31,6 +31,7 @@ if is_scipy_available():
from scipy import linalg

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,

@ -235,7 +236,7 @@ class FNetOutput(nn.Module):
return hidden_states

class FNetLayer(nn.Module):
class FNetLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward

@ -276,10 +277,7 @@ class FNetEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
else:
layer_outputs = layer_module(hidden_states)
layer_outputs = layer_module(hidden_states)

hidden_states = layer_outputs[0]
@ -25,6 +25,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging

@ -455,7 +456,7 @@ class FocalNetLayer(nn.Module):
return hidden_state

class FocalNetStage(nn.Module):
class FocalNetStage(GradientCheckpointingLayer):
def __init__(self, config, index, input_resolution):
super().__init__()

@ -560,14 +561,7 @@ class FocalNetEncoder(nn.Module):
all_reshaped_hidden_states += (reshaped_hidden_state,)

for i, stage_module in enumerate(self.stages):
if self.gradient_checkpointing and self.training:
stage_outputs = self._gradient_checkpointing_func(
stage_module.__call__,
hidden_states,
input_dimensions,
)
else:
stage_outputs = stage_module(hidden_states, input_dimensions)
stage_outputs = stage_module(hidden_states, input_dimensions)

hidden_states = stage_outputs[0]
hidden_states_before_downsampling = stage_outputs[1]
@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Union

import torch

@ -30,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,

@ -238,7 +238,7 @@ class Gemma2Attention(nn.Module):
return attn_output, attn_weights

class Gemma2DecoderLayer(nn.Module):
class Gemma2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

@ -466,30 +466,17 @@ class Gemma2Model(Gemma2PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Union

import torch

@ -25,6 +24,7 @@ from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack

@ -303,7 +303,7 @@ class Gemma2Attention(GemmaAttention):
return attn_output, attn_weights

class Gemma2DecoderLayer(nn.Module):
class Gemma2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

@ -449,30 +449,17 @@ class Gemma2Model(GemmaModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
@ -22,7 +22,6 @@
import copy
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union

import torch

@ -34,6 +33,7 @@ from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel

@ -364,7 +364,7 @@ class Gemma3Attention(nn.Module):
return attn_output, attn_weights

class Gemma3DecoderLayer(nn.Module):
class Gemma3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__()
self.config = config

@ -581,32 +581,18 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings_global,
position_embeddings_local,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
@ -16,7 +16,6 @@
import copy
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, Union

import torch

@ -27,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS

@ -443,7 +443,7 @@ class Gemma3Attention(Gemma2Attention):
return attn_output, attn_weights

class Gemma3DecoderLayer(nn.Module):
class Gemma3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__()
self.config = config

@ -632,32 +632,18 @@ class Gemma3TextModel(Gemma2Model):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings_global,
position_embeddings_local,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
@ -27,6 +27,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,

@ -343,7 +344,7 @@ class GitOutput(nn.Module):
return hidden_states

class GitLayer(nn.Module):
class GitLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward

@ -441,24 +442,14 @@ class GitEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
past_key_values,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
past_key_values,
output_attentions,
pixel_values_present,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
past_key_values,
output_attentions,
pixel_values_present,
)

hidden_states = layer_outputs[0]
if use_cache:

@ -723,7 +714,7 @@ class GitVisionAttention(nn.Module):

# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
class GitVisionEncoderLayer(nn.Module):
class GitVisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GitVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size

@ -840,21 +831,12 @@ class GitVisionEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
@ -31,6 +31,7 @@ import torch.nn.functional as F
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack

@ -192,7 +193,7 @@ class GotOcr2VisionAttention(nn.Module):
return outputs

class GotOcr2VisionLayer(nn.Module):
class GotOcr2VisionLayer(GradientCheckpointingLayer):
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

@ -463,13 +464,7 @@ class GotOcr2VisionEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

hidden_states = layer_outputs[0]
@ -30,6 +30,7 @@ from ...activations import ACT2FN, get_activation
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,

@ -368,7 +369,7 @@ class GPT2MLP(nn.Module):
return hidden_states

class GPT2Block(nn.Module):
class GPT2Block(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size

@ -922,32 +923,18 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
past_key_values,
cache_position,
causal_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
past_key_value=past_key_values,
cache_position=cache_position,
attention_mask=causal_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
outputs = block(
hidden_states,
past_key_values if not (self.gradient_checkpointing and self.training) else None,
cache_position,
causal_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)

hidden_states = outputs[0]
@ -25,6 +25,7 @@ from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,

@ -558,7 +559,7 @@ GPTBIGCODE_ATTENTION_CLASSES = {
}

class GPTBigCodeBlock(nn.Module):
class GPTBigCodeBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size

@ -759,6 +760,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.gradient_checkpointing and self.training and use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:

@ -891,29 +898,16 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
outputs = block(
hidden_states,
layer_past,
attention_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)

hidden_states = outputs[0]
if use_cache:
@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,

@ -431,7 +432,7 @@ class GPTNeoMLP(nn.Module):
return hidden_states

class GPTNeoBlock(nn.Module):
class GPTNeoBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_id=None):
super().__init__()
hidden_size = config.hidden_size

@ -635,27 +636,15 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
causal_mask,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)

hidden_states = outputs[0]
if use_cache:
@ -14,6 +14,7 @@ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,

@ -190,7 +191,7 @@ class GPTNeoXAttention(nn.Module):
return attn_output, attn_weights

class GPTNeoXLayer(nn.Module):
class GPTNeoXLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual

@ -415,32 +416,18 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
causal_mask,
position_ids,
head_mask[i],
use_cache,
past_key_values,
output_attentions,
cache_position,
position_embeddings,
)
else:
outputs = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
outputs = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = outputs[0]

if output_attentions:
@ -9,6 +9,7 @@ from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
@ -177,7 +178,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GPTNeoXLayer(nn.Module):
|
||||
class GPTNeoXLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
@ -362,32 +363,18 @@ class GPTNeoXModel(LlamaModel, nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
layer_past=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
layer_past=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
|
@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
@ -434,7 +435,7 @@ class GPTJMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTJBlock(nn.Module):
|
||||
class GPTJBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
super().__init__()
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
|
||||
@ -738,29 +739,16 @@ class GPTJModel(GPTJPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states=hidden_states,
|
||||
layer_past=past_key_values,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=past_key_values,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
|
@ -25,6 +25,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
@ -692,7 +693,7 @@ class GroupViTAttention(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
|
||||
class GroupViTEncoderLayer(nn.Module):
|
||||
class GroupViTEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: GroupViTConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -906,21 +907,12 @@ class GroupViTTextEncoder(nn.Module):
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -24,6 +24,7 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BackboneOutput,
|
||||
BaseModelOutput,
|
||||
@ -540,7 +541,7 @@ class HieraLayer(nn.Module):
|
||||
return (hidden_states, attn_weights)
|
||||
|
||||
|
||||
class HieraStage(nn.Module):
|
||||
class HieraStage(GradientCheckpointingLayer):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
@ -734,12 +735,7 @@ class HieraEncoder(nn.Module):
|
||||
for i, stage_module in enumerate(self.stages):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
stage_module.__call__, hidden_states, layer_head_mask, output_attentions
|
||||
)
|
||||
else:
|
||||
layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
|
||||
layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -32,6 +32,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@ -107,7 +108,7 @@ class HubertSamePadLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HubertNoLayerNormConvLayer(nn.Module):
|
||||
class HubertNoLayerNormConvLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_id=0):
|
||||
super().__init__()
|
||||
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
||||
@ -128,7 +129,7 @@ class HubertNoLayerNormConvLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HubertLayerNormConvLayer(nn.Module):
|
||||
class HubertLayerNormConvLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_id=0):
|
||||
super().__init__()
|
||||
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
||||
@ -155,7 +156,7 @@ class HubertLayerNormConvLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HubertGroupNormConvLayer(nn.Module):
|
||||
class HubertGroupNormConvLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_id=0):
|
||||
super().__init__()
|
||||
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
||||
@ -212,13 +213,7 @@ class HubertFeatureEncoder(nn.Module):
|
||||
hidden_states.requires_grad = True
|
||||
|
||||
for conv_layer in self.conv_layers:
|
||||
if self._requires_grad and self.gradient_checkpointing and self.training:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
conv_layer.__call__,
|
||||
hidden_states,
|
||||
)
|
||||
else:
|
||||
hidden_states = conv_layer(hidden_states)
|
||||
hidden_states = conv_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -417,7 +412,7 @@ class HubertFeedForward(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HubertEncoderLayer(nn.Module):
|
||||
class HubertEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.attention = HubertAttention(
|
||||
@ -501,17 +496,9 @@ class HubertEncoder(nn.Module):
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if skip_the_layer:
|
||||
@ -579,7 +566,7 @@ class HubertAttnAdapterLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HubertEncoderLayerStableLayerNorm(nn.Module):
|
||||
class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.attention = HubertAttention(
|
||||
@ -675,17 +662,9 @@ class HubertEncoderStableLayerNorm(nn.Module):
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if skip_the_layer:
|
||||
|
@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@ -668,7 +669,7 @@ class IdeficsAttention(nn.Module):
|
||||
|
||||
|
||||
# this was adapted from LlamaDecoderLayer
|
||||
class IdeficsDecoderLayer(nn.Module):
|
||||
class IdeficsDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -749,7 +750,7 @@ class IdeficsDecoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class IdeficsGatedCrossAttentionLayer(nn.Module):
|
||||
class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -1185,95 +1186,32 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
def vblock(
|
||||
main_block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
image_hidden_states,
|
||||
image_attention_mask,
|
||||
cross_attention_gate,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
layer_idx,
|
||||
cross_layer_interval,
|
||||
gated_cross_attn_layers,
|
||||
cache_position,
|
||||
):
|
||||
# TODO(ls): Add cross attention values to respective lists
|
||||
if layer_idx % cross_layer_interval == 0:
|
||||
xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
|
||||
outputs = xblock(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_hidden_states=image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
cross_attention_gate=cross_attention_gate,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
past_key_value=None, # not implemented
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
|
||||
layer_outputs = main_block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return layer_outputs
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
past_key_values = None
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
vblock,
|
||||
decoder_layer,
|
||||
# TODO(ls): Add cross attention values to respective lists
|
||||
if idx % self.cross_layer_interval == 0:
|
||||
cross_attn_block = self.gated_cross_attn_layers[idx // self.cross_layer_interval]
|
||||
outputs = cross_attn_block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
image_hidden_states,
|
||||
image_attention_mask,
|
||||
cross_attention_gate,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
idx,
|
||||
self.cross_layer_interval,
|
||||
self.gated_cross_attn_layers,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = vblock(
|
||||
decoder_layer,
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
image_hidden_states=image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
cross_attention_gate=cross_attention_gate,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
layer_idx=idx,
|
||||
cross_layer_interval=self.cross_layer_interval,
|
||||
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
||||
cache_position=cache_position,
|
||||
past_key_value=None, # not implemented
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
|
@ -23,6 +23,7 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...utils import (
|
||||
@ -283,7 +284,7 @@ class IdeficsVisionMLP(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision
|
||||
class IdeficsVisionEncoderLayer(nn.Module):
|
||||
class IdeficsVisionEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: IdeficsVisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -400,21 +401,12 @@ class IdeficsVisionEncoder(nn.Module):
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@ -339,7 +340,7 @@ class Idefics2MultiheadAttentionPoolingHead(nn.Module):
|
||||
return hidden_state[:, 0]
|
||||
|
||||
|
||||
class Idefics2EncoderLayer(nn.Module):
|
||||
class Idefics2EncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Idefics2VisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -448,19 +449,11 @@ class Idefics2Encoder(nn.Module):
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -26,6 +26,7 @@ from ...cache_utils import DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@ -300,7 +301,7 @@ class Idefics3SimpleMLP(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
|
||||
class Idefics3EncoderLayer(nn.Module):
|
||||
class Idefics3EncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Idefics3VisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -409,19 +410,11 @@ class Idefics3Encoder(nn.Module):
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -12,6 +12,7 @@ import torch.nn as nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
@ -357,7 +358,7 @@ class IJepaOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IJepaLayer(nn.Module):
|
||||
class IJepaLayer(GradientCheckpointingLayer):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config: IJepaConfig) -> None:
|
||||
@ -423,15 +424,7 @@ class IJepaEncoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
@ -401,7 +402,7 @@ class ImageGPTMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ImageGPTBlock(nn.Module):
|
||||
class ImageGPTBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -719,29 +720,16 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
|
@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -744,7 +745,7 @@ class InformerProbSparseAttention(nn.Module):
|
||||
|
||||
|
||||
# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py
|
||||
class InformerConvLayer(nn.Module):
|
||||
class InformerConvLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, c_in):
|
||||
super().__init__()
|
||||
self.downConv = nn.Conv1d(
|
||||
@ -767,7 +768,7 @@ class InformerConvLayer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class InformerEncoderLayer(nn.Module):
|
||||
class InformerEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: InformerConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -845,7 +846,7 @@ class InformerEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class InformerDecoderLayer(nn.Module):
|
||||
class InformerDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -1086,27 +1087,15 @@ class InformerEncoder(InformerPreTrainedModel):
|
||||
if to_drop:
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions,
|
||||
)
|
||||
if conv_layer is not None:
|
||||
output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
|
||||
layer_outputs = (output,) + layer_outputs[1:]
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
if conv_layer is not None:
|
||||
output = conv_layer(layer_outputs[0])
|
||||
layer_outputs = (output,) + layer_outputs[1:]
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
if conv_layer is not None:
|
||||
output = conv_layer(layer_outputs[0])
|
||||
layer_outputs = (output,) + layer_outputs[1:]
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -1299,35 +1288,18 @@ class InformerDecoder(InformerPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
|
@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
)
|
||||
@ -433,7 +434,7 @@ class InformerProbSparseAttention(nn.Module):
|
||||
|
||||
|
||||
# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py
|
||||
class InformerConvLayer(nn.Module):
|
||||
class InformerConvLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, c_in):
|
||||
super().__init__()
|
||||
self.downConv = nn.Conv1d(
|
||||
@ -610,27 +611,15 @@ class InformerEncoder(TimeSeriesTransformerEncoder):
|
||||
if to_drop:
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions,
|
||||
)
|
||||
if conv_layer is not None:
|
||||
output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
|
||||
layer_outputs = (output,) + layer_outputs[1:]
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
if conv_layer is not None:
|
||||
output = conv_layer(layer_outputs[0])
|
||||
layer_outputs = (output,) + layer_outputs[1:]
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
if conv_layer is not None:
|
||||
output = conv_layer(layer_outputs[0])
|
||||
layer_outputs = (output,) + layer_outputs[1:]
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -427,7 +427,7 @@ class InstructBlipEncoder(nn.Module):
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
@ -889,11 +889,11 @@ class InstructBlipQFormerEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
query_length,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
query_length=query_length,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -356,7 +356,7 @@ class InstructBlipVideoEncoder(nn.Module):
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
@ -750,11 +750,11 @@ class InstructBlipVideoQFormerEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
query_length,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
query_length=query_length,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -31,6 +31,7 @@ from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@ -383,7 +384,7 @@ class InternVLVisionMLP(nn.Module):
|
||||
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
|
||||
|
||||
|
||||
class InternVLVisionLayer(nn.Module):
|
||||
class InternVLVisionLayer(GradientCheckpointingLayer):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||
@ -452,12 +453,7 @@ class InternVLVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__, hidden_states, output_attentions
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -24,6 +24,7 @@ import torch.utils.checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@ -334,7 +335,7 @@ class InternVLVisionMLP(CLIPMLP):
|
||||
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
|
||||
|
||||
|
||||
class InternVLVisionLayer(nn.Module):
|
||||
class InternVLVisionLayer(GradientCheckpointingLayer):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||
@ -403,12 +404,7 @@ class InternVLVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__, hidden_states, output_attentions
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
@ -763,7 +764,7 @@ JETMOE_ATTENTION_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
class JetMoeBlock(nn.Module):
|
||||
class JetMoeBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
|
||||
"""
|
||||
Initialize the JetMoeBlock module.
|
||||
@ -967,28 +968,15 @@ class JetMoeModel(JetMoePreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
causal_mask,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
@ -25,6 +25,7 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -404,7 +405,7 @@ class Kosmos2VisionMLP(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision
|
||||
class Kosmos2VisionEncoderLayer(nn.Module):
|
||||
class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Kosmos2VisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -521,21 +522,12 @@ class Kosmos2VisionEncoder(nn.Module):
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -840,7 +832,7 @@ class Kosmos2TextFFN(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Kosmos2TextBlock(nn.Module):
|
||||
class Kosmos2TextBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Kosmos2TextConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.embed_dim
|
||||
@ -1138,34 +1130,18 @@ class Kosmos2TextTransformer(nn.Module):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
|
@ -23,6 +23,7 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
@ -358,7 +359,7 @@ class LayoutLMOutput(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
|
||||
class LayoutLMLayer(nn.Module):
|
||||
class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -484,27 +485,15 @@ class LayoutLMEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
|
@ -23,6 +23,7 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
@ -261,7 +262,7 @@ class LayoutLMv2Output(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LayoutLMv2Layer(nn.Module):
|
||||
class LayoutLMv2Layer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -436,25 +437,14 @@ class LayoutLMv2Encoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
rel_pos=rel_pos,
|
||||
rel_2d_pos=rel_2d_pos,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
rel_pos=rel_pos,
|
||||
rel_2d_pos=rel_2d_pos,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
rel_pos=rel_pos,
|
||||
rel_2d_pos=rel_2d_pos,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
|
@ -25,6 +25,7 @@ import torch.utils.checkpoint
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
@ -358,7 +359,7 @@ class LayoutLMv3Attention(nn.Module):
|
||||
|
||||
|
||||
# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
|
||||
class LayoutLMv3Layer(nn.Module):
|
||||
class LayoutLMv3Layer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
@ -514,25 +515,14 @@ class LayoutLMv3Encoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
rel_pos,
|
||||
rel_2d_pos,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
rel_pos=rel_pos,
|
||||
rel_2d_pos=rel_2d_pos,
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
rel_pos=rel_pos,
|
||||
rel_2d_pos=rel_2d_pos,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
|
@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
@ -900,7 +901,7 @@ class LEDDecoderAttention(nn.Module):
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
class LEDEncoderLayer(nn.Module):
|
||||
class LEDEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: LEDConfig, layer_id: int):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -962,7 +963,7 @@ class LEDEncoderLayer(nn.Module):
|
||||
return (hidden_states,) + attn_outputs[1:]
|
||||
|
||||
|
||||
class LEDDecoderLayer(nn.Module):
|
||||
class LEDDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: LEDConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -1680,27 +1681,15 @@ class LEDEncoder(LEDPreTrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
@ -1943,33 +1932,17 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
combined_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
combined_attention_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user