Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)

* support passing flash_attn_kwargs when gradient_checkpointing is enabled

* make modeling_deepseek_v3.py consistent with modular_deepseek_v3.py
efsotr 2025-03-31 16:53:02 +08:00 committed by GitHub
parent bd41b9c1ac
commit 2b4734bd49
29 changed files with 58 additions and 30 deletions
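
Every file in this diff applies the same two-line change: import `partial` from `functools`, and wrap `decoder_layer.__call__` in `partial(..., **flash_attn_kwargs)` before handing it to `self._gradient_checkpointing_func`. Because the checkpointing call site only passes positional arguments, binding the FlashAttention keyword arguments into the callable is what lets them reach the decoder layer during checkpointed training. Below is a minimal standalone sketch of the idea; the toy layer and the `cu_seqlens` kwarg are illustrative placeholders, not code from this patch.

# Sketch only: demonstrates why the patch binds flash_attn_kwargs with functools.partial.
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


class ToyDecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, **flash_attn_kwargs):
        # A real decoder layer would forward these kwargs to its attention module;
        # here we only verify that they survive the checkpointing wrapper.
        assert "cu_seqlens" in flash_attn_kwargs
        return self.proj(hidden_states)


layer = ToyDecoderLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
flash_attn_kwargs = {"cu_seqlens": torch.tensor([0, 4, 8])}  # hypothetical example value

# Only positional tensors go through the checkpoint call, so the keyword arguments are
# bound into the callable itself; the checkpointed call signature stays unchanged.
layer_outputs = checkpoint(
    partial(layer.__call__, **flash_attn_kwargs),
    hidden_states,
    None,  # attention_mask
    use_reentrant=False,
)
layer_outputs.sum().backward()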

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_dummy.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -544,7 +545,7 @@ class DummyModel(DummyPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_multimodal1.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -544,7 +545,7 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 from ...activations import ACT2FN
@@ -963,7 +964,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -27,6 +27,7 @@
 # This file is based on the LLama model definition file in transformers
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -613,7 +614,7 @@ class CohereModel(CoherePreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -634,7 +635,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -533,7 +534,7 @@ class Cohere2Model(Gemma2Model):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -5,6 +5,7 @@
 # modular_deepseek_v3.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -759,7 +760,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from functools import partial
 from typing import Optional, Tuple, Union
 import torch
@@ -852,7 +853,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -21,7 +21,7 @@
 # limitations under the License.
 import math
-from functools import cached_property
+from functools import cached_property, partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -1439,7 +1439,7 @@ class Emu3TextModel(Emu3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -645,7 +646,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -491,7 +492,7 @@ class Gemma2Model(GemmaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

View File

@@ -22,6 +22,7 @@
 import copy
 from collections.abc import Callable
 from dataclasses import dataclass
+from functools import partial
 from typing import List, Optional, Tuple, Union
 import torch
@@ -732,7 +733,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings_global,
                     position_embeddings_local,

View File

@@ -16,6 +16,7 @@
 import copy
 from collections.abc import Callable
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
@@ -662,7 +663,7 @@ class Gemma3TextModel(Gemma2Model):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings_global,
                     position_embeddings_local,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -594,7 +595,7 @@ class GlmModel(GlmPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -593,7 +594,7 @@ class GraniteModel(GranitePreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import List, Optional, Tuple, Union
 import torch
@@ -185,7 +186,7 @@ class GraniteModel(LlamaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -20,6 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -581,7 +582,7 @@ class HeliumModel(HeliumPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -17,6 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -583,7 +584,7 @@ class LlamaModel(LlamaPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_mistral.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -548,7 +549,7 @@ class MistralModel(MistralPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -24,6 +24,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 import torch
@@ -672,7 +673,7 @@ class MixtralModel(MixtralPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -19,6 +19,7 @@
 # limitations under the License.
 """PyTorch Mixtral model."""
+from functools import partial
 from typing import List, Optional, Tuple, Union
 import torch
@@ -400,7 +401,7 @@ class MixtralModel(MistralModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import numpy as np
@@ -936,7 +937,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     encoder_hidden_states,

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -832,7 +833,7 @@ class MoonshineDecoder(LlamaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     encoder_hidden_states,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_olmo.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -559,7 +560,7 @@ class OlmoModel(OlmoPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_olmo2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -560,7 +561,7 @@ class Olmo2Model(Olmo2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_phi.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -553,7 +554,7 @@ class PhiModel(PhiPreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -243,7 +244,7 @@ class PhiModel(LlamaModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -20,6 +20,7 @@
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -623,7 +624,7 @@ class Phi3Model(Phi3PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

View File

@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_qwen2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -561,7 +562,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,