Introduce GradientCheckpointingLayer (#37223)

* GradientCheckpointingLayer

* trigger

* Move GC layer to a separate file

* Update import

* Expose and document GC layer

* Fix dummy

* Apply to llama-based models

* Update modulars

* Update a few more models for consistency

* Update glm4

* Update Janus
Pavel Iakubovskii 2025-04-22 11:33:31 +01:00 committed by GitHub
parent 413f9bbf80
commit 9167fadab9
35 changed files with 435 additions and 761 deletions

View File

@ -20,6 +20,10 @@ This page lists all the custom layers used by the library, as well as the utility functions it provides for modeling.
Most of those are only useful if you are studying the code of the models in the library.
## Layers
[[autodoc]] GradientCheckpointingLayer
## Attention Functions
[[autodoc]] AttentionInterface

View File

@ -438,6 +438,7 @@ else:
]
_import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_layers"] = ["GradientCheckpointingLayer"]
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
@ -911,6 +912,7 @@ if TYPE_CHECKING:
from .model_debugging_utils import (
model_addition_debugger_context,
)
from .modeling_layers import GradientCheckpointingLayer
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from .modeling_utils import AttentionInterface, PreTrainedModel
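Taken together, the two hunks above register the new layer in the lazy `_import_structure` map and in the `TYPE_CHECKING` branch, so it becomes part of the top-level API. A minimal check of the resulting import surface (a sketch, assuming a transformers build that includes this commit):

```python
# The first import resolves lazily through _import_structure; the second
# hits the defining module directly. Both should yield the same class.
from transformers import GradientCheckpointingLayer
from transformers.modeling_layers import GradientCheckpointingLayer as GCLayer

assert GradientCheckpointingLayer is GCLayer
```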

View File

@ -0,0 +1,48 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch.nn as nn
class GradientCheckpointingLayer(nn.Module):
"""Base class for layers with gradient checkpointing.
This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
(`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
Important:
When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
Example:
```python
>>> # Correct - hidden_states passed as positional arg
>>> out = self.layer(hidden_states, attention_mask=attention_mask)
>>> # Incorrect - hidden_states passed as keyword arg
>>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
```
"""
gradient_checkpointing = False
def __call__(self, *args, **kwargs):
if self.gradient_checkpointing and self.training:
return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
return super().__call__(*args, **kwargs)
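A short sketch of how a subclass picks this behavior up, mirroring what `model.gradient_checkpointing_enable()` arranges for each such layer. The toy layer and the explicit attribute assignments below are illustrative stand-ins, not library code:

```python
from functools import partial

import torch
import torch.nn as nn

from transformers import GradientCheckpointingLayer


class ToyDecoderLayer(GradientCheckpointingLayer):
    # Hypothetical minimal layer, standing in for e.g. LlamaDecoderLayer.
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        out = self.linear(hidden_states)
        return out if attention_mask is None else out * attention_mask


layer = ToyDecoderLayer(16).train()
# Roughly what set_gradient_checkpointing() does for each such layer:
layer.gradient_checkpointing = True
layer._gradient_checkpointing_func = partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

x = torch.randn(2, 4, 16, requires_grad=True)
mask = torch.ones(2, 4, 1)
# hidden_states is positional (it needs gradients); attention_mask is a
# keyword argument and gets bound via functools.partial inside __call__.
out = layer(x, attention_mask=mask)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4, 16])
```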

View File

@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
from ...activations import ACT2FN
@ -28,6 +27,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -590,7 +590,7 @@ class AriaTextAttention(nn.Module):
return attn_output, attn_weights
class AriaTextDecoderLayer(nn.Module):
class AriaTextDecoderLayer(GradientCheckpointingLayer):
"""
Aria Text Decoder Layer.
@ -940,30 +940,17 @@ class AriaTextModel(AriaTextPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
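This collapsed call site is representative of every model touched by the commit: the `if self.gradient_checkpointing and self.training` branch now lives inside `GradientCheckpointingLayer.__call__`, so the forward loop always issues a plain layer call. When checkpointing is active, the call above is roughly equivalent to the following sketch (not the literal library code):

```python
# All keyword arguments, including attention_mask and **flash_attn_kwargs,
# are bound with functools.partial; only the positional hidden_states is
# routed through the checkpoint function, so it keeps its gradient path.
layer_outputs = decoder_layer._gradient_checkpointing_func(
    partial(
        nn.Module.__call__,
        decoder_layer,
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
        **flash_attn_kwargs,
    ),
    hidden_states,
)
```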

View File

@ -27,7 +27,6 @@
# This file is based on the LLama model definition file in transformers
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
@ -38,6 +37,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -301,7 +301,7 @@ class CohereAttention(nn.Module):
return attn_output, attn_weights
class CohereDecoderLayer(nn.Module):
class CohereDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: CohereConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -589,30 +589,17 @@ class CohereModel(CoherePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -30,6 +30,7 @@ from torch import nn
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@ -209,7 +210,7 @@ class CohereAttention(LlamaAttention):
return attn_output, attn_weights
class CohereDecoderLayer(nn.Module):
class CohereDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: CohereConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

View File

@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
@ -29,6 +28,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -290,7 +290,7 @@ class Cohere2MLP(nn.Module):
return down_proj
class Cohere2DecoderLayer(nn.Module):
class Cohere2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Cohere2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -612,28 +612,16 @@ class Cohere2Model(Cohere2PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple
import torch
@ -526,28 +525,16 @@ class Cohere2Model(Gemma2Model):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -5,7 +5,6 @@
# modular_deepseek_v3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -18,6 +17,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -454,7 +454,7 @@ class DeepseekV3Attention(nn.Module):
return attn_output, attn_weights
class DeepseekV3DecoderLayer(nn.Module):
class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DeepseekV3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -734,30 +734,17 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -22,7 +22,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import partial
from typing import Optional, Tuple, Union
import torch
@ -38,6 +37,7 @@ from ...modeling_flash_attention_utils import (
_flash_attention_forward,
flash_attn_supports_top_left_mask,
)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -526,7 +526,7 @@ DIFFLLAMA_ATTENTION_CLASSES = {
}
class DiffLlamaDecoderLayer(nn.Module):
class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -837,30 +837,17 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -21,7 +21,7 @@
# limitations under the License.
import math
from functools import cached_property, partial
from functools import cached_property
from typing import Callable, List, Optional, Tuple, Union
import torch
@ -34,6 +34,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -248,7 +249,7 @@ class Emu3Attention(nn.Module):
return attn_output, attn_weights
class Emu3DecoderLayer(nn.Module):
class Emu3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Emu3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -1422,30 +1423,17 @@ class Emu3TextModel(Emu3PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -29,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -282,7 +283,7 @@ class GemmaAttention(nn.Module):
return attn_output, attn_weights
class GemmaDecoderLayer(nn.Module):
class GemmaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -557,29 +558,16 @@ class GemmaModel(GemmaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]

View File

@ -434,29 +434,16 @@ class GemmaModel(LlamaModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]

View File

@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -31,6 +30,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -300,7 +300,7 @@ class GlmRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class GlmDecoderLayer(nn.Module):
class GlmDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GlmConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -572,30 +572,17 @@ class GlmModel(GlmPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -31,6 +30,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -83,7 +83,7 @@ class Glm4MLP(nn.Module):
return self.down_proj(up_states)
class Glm4DecoderLayer(nn.Module):
class Glm4DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Glm4Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -580,30 +580,17 @@ class Glm4Model(Glm4PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -15,11 +15,11 @@
# limitations under the License.
from typing import Optional, Tuple, Union
import torch.nn as nn
import torch.utils.checkpoint
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import CausalLMOutputWithPast
from ...processing_utils import Unpack
from ...utils import LossKwargs, logging
@ -43,7 +43,7 @@ class Glm4MLP(Phi3MLP):
pass
class Glm4DecoderLayer(nn.Module):
class Glm4DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Glm4Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

View File

@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
@ -31,6 +30,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -243,7 +243,7 @@ class GraniteMLP(nn.Module):
return down_proj
class GraniteDecoderLayer(nn.Module):
class GraniteDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GraniteConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -572,30 +572,17 @@ class GraniteModel(GranitePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
@ -192,30 +191,17 @@ class GraniteModel(LlamaModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -20,7 +20,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -31,6 +30,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -285,7 +285,7 @@ class HeliumAttention(nn.Module):
return attn_output, attn_weights
class HeliumDecoderLayer(nn.Module):
class HeliumDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
@ -557,30 +557,17 @@ class HeliumModel(HeliumPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -31,6 +31,7 @@ from ...cache_utils import Cache
from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
from ...generation.utils import GenerateDecoderOnlyOutput
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@ -429,7 +430,7 @@ class JanusVisionMLP(nn.Module):
return hidden_states
class JanusVisionEncoderLayer(nn.Module):
class JanusVisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: JanusVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -536,19 +537,12 @@ class JanusVisionEncoder(nn.Module):
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -17,7 +17,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -29,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -290,7 +290,7 @@ class LlamaAttention(nn.Module):
return attn_output, attn_weights
class LlamaDecoderLayer(nn.Module):
class LlamaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -562,30 +562,17 @@ class LlamaModel(LlamaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -4,7 +4,6 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_mistral.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
@ -16,6 +15,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -228,7 +228,7 @@ class MistralRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class MistralDecoderLayer(nn.Module):
class MistralDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -532,30 +532,17 @@ class MistralModel(MistralPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -18,7 +18,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import numpy as np
@ -34,6 +33,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
@ -351,7 +351,7 @@ class MoonshineRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class MoonshineEncoderLayer(nn.Module):
class MoonshineEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: MoonshineConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -410,7 +410,7 @@ class MoonshineEncoderLayer(nn.Module):
return outputs
class MoonshineDecoderLayer(nn.Module):
class MoonshineDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
@ -668,27 +668,14 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
False,
None,
position_embeddings,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
@ -912,33 +899,19 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
encoder_hidden_states,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -27,6 +26,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
@ -428,7 +428,7 @@ class MoonshineEncoderLayer(LlamaDecoderLayer):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
class MoonshineDecoderLayer(nn.Module):
class MoonshineDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
@ -686,27 +686,14 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
False,
None,
position_embeddings,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
@ -832,33 +819,19 @@ class MoonshineDecoder(LlamaModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
encoder_hidden_states,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -4,7 +4,6 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_olmo.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -16,6 +15,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -229,7 +229,7 @@ class OlmoAttention(nn.Module):
return attn_output, attn_weights
class OlmoDecoderLayer(nn.Module):
class OlmoDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: OlmoConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -532,30 +532,17 @@ class OlmoModel(OlmoPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -4,7 +4,6 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_olmo2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -16,6 +15,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@ -233,7 +233,7 @@ class Olmo2MLP(nn.Module):
return down_proj
class Olmo2DecoderLayer(nn.Module):
class Olmo2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Olmo2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -538,30 +538,17 @@ class Olmo2Model(Olmo2PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -20,7 +20,6 @@
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -32,6 +31,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -257,7 +257,7 @@ class Phi3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Phi3DecoderLayer(nn.Module):
class Phi3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Phi3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -587,30 +587,17 @@ class Phi3Model(Phi3PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -36,6 +36,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
@ -154,7 +155,7 @@ class Phi4MultimodalVisionAttention(nn.Module):
return attn_output, attn_weights
class Phi4MultimodalVisionEncoderLayer(nn.Module):
class Phi4MultimodalVisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
@ -262,19 +263,12 @@ class Phi4MultimodalVisionEncoder(nn.Module):
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1213,14 +1207,7 @@ class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
for layer in self.encoders:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
)
else:
hidden_states = layer(hidden_states, attention_mask)
hidden_states = layer(hidden_states, attention_mask)
if unfolded:
embed_dim = hidden_states.shape[-1]
@ -1483,7 +1470,7 @@ class Phi4MultimodalAttention(nn.Module):
return attn_output, attn_weights
class Phi4MultimodalDecoderLayer(nn.Module):
class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -1885,30 +1872,17 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -1265,14 +1265,7 @@ class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
for layer in self.encoders:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
)
else:
hidden_states = layer(hidden_states, attention_mask)
hidden_states = layer(hidden_states, attention_mask)
if unfolded:
embed_dim = hidden_states.shape[-1]
@ -1655,30 +1648,17 @@ class Phi4MultimodalModel(Phi3Model, nn.Module):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -4,7 +4,6 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_qwen2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -16,6 +15,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -236,7 +236,7 @@ class Qwen2RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2DecoderLayer(nn.Module):
class Qwen2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -545,30 +545,17 @@ class Qwen2Model(Qwen2PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@ -31,6 +30,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -261,7 +261,7 @@ class Qwen3Attention(nn.Module):
return attn_output, attn_weights
class Qwen3DecoderLayer(nn.Module):
class Qwen3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -572,30 +572,17 @@ class Qwen3Model(Qwen3PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -28,6 +28,7 @@ from torch.nn.init import _calculate_fan_in_and_fan_out
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
@ -465,7 +466,7 @@ class SiglipMLP(nn.Module):
return hidden_states
class SiglipEncoderLayer(nn.Module):
class SiglipEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
super().__init__()
self.embed_dim = config.hidden_size
@ -743,19 +744,12 @@ class SiglipEncoder(nn.Module):
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -32,6 +32,7 @@ from torch.nn.init import _calculate_fan_in_and_fan_out
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
@ -356,7 +357,7 @@ class Siglip2MLP(nn.Module):
return hidden_states
class Siglip2EncoderLayer(nn.Module):
class Siglip2EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
super().__init__()
self.embed_dim = config.hidden_size
@ -462,19 +463,12 @@ class Siglip2Encoder(nn.Module):
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@ -34,6 +34,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -230,7 +231,7 @@ class Starcoder2Attention(nn.Module):
return attn_output, attn_weights
class Starcoder2DecoderLayer(nn.Module):
class Starcoder2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Starcoder2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

View File

@ -542,6 +542,13 @@ def model_addition_debugger_context(*args, **kwargs):
requires_backends(model_addition_debugger_context, ["torch"])
class GradientCheckpointingLayer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
ROPE_INIT_FUNCTIONS = None
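The dummy keeps `GradientCheckpointingLayer` importable in environments without PyTorch while failing loudly on use. A sketch of the expected behavior, assuming the standard `requires_backends` machinery:

```python
# In a torch-less environment the import succeeds, but instantiation
# raises an ImportError telling the user to install torch.
from transformers.utils.dummy_pt_objects import GradientCheckpointingLayer

try:
    GradientCheckpointingLayer()
except ImportError as err:
    print(err)
```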