Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)

* support passing flash_attn_kwargs when gradient_checkpointing is enabled

* make modeling_deepseek_v3.py consistent with modular_deepseek_v3.py
efsotr 2025-03-31 16:53:02 +08:00 committed by GitHub
parent bd41b9c1ac
commit 2b4734bd49
29 changed files with 58 additions and 30 deletions
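
Every hunk below follows the same recipe: the decoder-layer call that goes through self._gradient_checkpointing_func is wrapped in functools.partial so that **flash_attn_kwargs still reach the layer. The checkpointing wrapper is only handed positional tensors, and reentrant checkpointing does not forward extra keyword arguments, so binding them into the callable is the minimal fix. A small, self-contained sketch of the mechanism (the layer and the cu_seq_lens_q keyword are illustrative stand-ins, not the transformers code):

from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def layer(hidden_states, attention_mask=None, cu_seq_lens_q=None):
    # Stand-in for decoder_layer.__call__; cu_seq_lens_q plays the role of a
    # FlashAttention kwarg that must survive the checkpointed re-forward.
    assert cu_seq_lens_q is not None, "flash-attn kwarg was dropped"
    return hidden_states * 2


hidden_states = torch.ones(2, 4, requires_grad=True)
flash_attn_kwargs = {"cu_seq_lens_q": torch.tensor([0, 4, 8], dtype=torch.int32)}

# Bind the kwargs up front, then hand the wrapper positional tensors only,
# mirroring partial(decoder_layer.__call__, **flash_attn_kwargs) in the diffs.
out = checkpoint(partial(layer, **flash_attn_kwargs), hidden_states, use_reentrant=True)
out.sum().backward()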

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_dummy.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -544,7 +545,7 @@ class DummyModel(DummyPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,
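
For context, a condensed, runnable model of the loop each file changes: the non-checkpointed branch already forwarded **flash_attn_kwargs to the layer, and the checkpointed branch now matches it by binding the kwargs into the callable. Layer internals and the kwarg name are illustrative; the real argument lists vary per model, as the hunks below show.

from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyDecoderLayer(nn.Module):
    # Illustrative stand-in for a decoder layer that consumes a FlashAttention-style kwarg.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None, cu_seq_lens_q=None):
        assert cu_seq_lens_q is not None, "flash-attn kwarg did not reach the layer"
        return (self.proj(hidden_states),)


class TinyModel(nn.Module):
    def __init__(self, dim=8, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(TinyDecoderLayer(dim) for _ in range(num_layers))
        self.gradient_checkpointing = True

    def forward(self, hidden_states, attention_mask=None, **flash_attn_kwargs):
        for decoder_layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Same shape as the hunks: kwargs bound via partial, tensors positional.
                layer_outputs = checkpoint(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    attention_mask,
                    use_reentrant=True,
                )
            else:
                # This branch already passed the kwargs through before the change.
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    **flash_attn_kwargs,
                )
            hidden_states = layer_outputs[0]
        return hidden_states


model = TinyModel().train()
x = torch.randn(2, 4, 8, requires_grad=True)
out = model(x, cu_seq_lens_q=torch.tensor([0, 4, 8], dtype=torch.int32))
out.sum().backward()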

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_multimodal1.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -544,7 +545,7 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 from ...activations import ACT2FN
@@ -963,7 +964,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -27,6 +27,7 @@
 # This file is based on the LLama model definition file in transformers
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -613,7 +614,7 @@ class CohereModel(CoherePreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -634,7 +635,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -533,7 +534,7 @@ class Cohere2Model(Gemma2Model):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -5,6 +5,7 @@
 # modular_deepseek_v3.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -759,7 +760,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from functools import partial
 from typing import Optional, Tuple, Union
 import torch
@@ -852,7 +853,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -21,7 +21,7 @@
 # limitations under the License.
 import math
-from functools import cached_property
+from functools import cached_property, partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -1439,7 +1439,7 @@ class Emu3TextModel(Emu3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -645,7 +646,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -491,7 +492,7 @@ class Gemma2Model(GemmaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -22,6 +22,7 @@
 import copy
 from collections.abc import Callable
 from dataclasses import dataclass
+from functools import partial
 from typing import List, Optional, Tuple, Union
 import torch
@@ -732,7 +733,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings_global,
                     position_embeddings_local,

View File

@@ -16,6 +16,7 @@
 import copy
 from collections.abc import Callable
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
@@ -662,7 +663,7 @@ class Gemma3TextModel(Gemma2Model):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings_global,
                     position_embeddings_local,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -594,7 +595,7 @@ class GlmModel(GlmPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -593,7 +594,7 @@ class GraniteModel(GranitePreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import List, Optional, Tuple, Union
 import torch
@@ -185,7 +186,7 @@ class GraniteModel(LlamaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -20,6 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -581,7 +582,7 @@ class HeliumModel(HeliumPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -17,6 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -583,7 +584,7 @@ class LlamaModel(LlamaPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_mistral.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -548,7 +549,7 @@ class MistralModel(MistralPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -24,6 +24,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -672,7 +673,7 @@ class MixtralModel(MixtralPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # limitations under the License.
 """PyTorch Mixtral model."""
+from functools import partial
 from typing import List, Optional, Tuple, Union
 import torch
@@ -400,7 +401,7 @@ class MixtralModel(MistralModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import numpy as np
@@ -936,7 +937,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     encoder_hidden_states,

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -832,7 +833,7 @@ class MoonshineDecoder(LlamaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     encoder_hidden_states,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_olmo.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -559,7 +560,7 @@ class OlmoModel(OlmoPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_olmo2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -560,7 +561,7 @@ class Olmo2Model(Olmo2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_phi.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -553,7 +554,7 @@ class PhiModel(PhiPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -243,7 +244,7 @@ class PhiModel(LlamaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -20,6 +20,7 @@
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -623,7 +624,7 @@ class Phi3Model(Phi3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_qwen2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -561,7 +562,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,
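
The user-visible effect: FlashAttention keyword arguments (for example the packed-sequence cu_seq_lens_*/max_length_* fields of FlashAttentionKwargs) are no longer dropped when gradient checkpointing is enabled. A hedged usage sketch — it assumes a CUDA device with flash-attn installed, and the checkpoint name, sequence split, and tensor values are illustrative only:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",               # illustrative checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).cuda()
model.gradient_checkpointing_enable()
model.train()

# Two sequences of length 3 packed into one row; the cu_seq_lens_*/max_length_*
# kwargs tell the FlashAttention kernels where each sequence starts and ends.
input_ids = torch.tensor([[101, 2001, 2002, 101, 3001, 3002]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2]], device="cuda")
cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32, device="cuda")

out = model(
    input_ids=input_ids,
    position_ids=position_ids,
    labels=input_ids,
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=3,
    max_length_k=3,
)
out.loss.backward()  # the kwargs now reach the layers inside the checkpointed re-forward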