Use Gradient Checkpointing Layer in Jamba & Blip Related Models (#38310)

* Use gradient checkpointing class in blip classes

* Use gradient checkpointing class in jamba/bamba
Alex Brooks 2025-05-23 11:35:25 -06:00 committed by GitHub
parent 53fb245eb6
commit e64ed0304c
8 changed files with 128 additions and 229 deletions
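
All eight files follow the same pattern: the hand-rolled `if self.gradient_checkpointing and self.training:` branch in each forward loop is deleted, and the layer classes instead inherit from `GradientCheckpointingLayer` (imported from `...modeling_layers`), which makes the checkpointing decision inside the layer call itself. A minimal sketch of that idea, assuming a simplified flag and plain `torch.utils.checkpoint` rather than the library's actual implementation in `modeling_layers.py`:

from torch import nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayer(nn.Module):
    """Sketch of the idea: a layer whose call checkpoints itself while training."""

    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Recompute activations during backward instead of storing them;
            # non-reentrant checkpointing forwards keyword arguments, so one
            # call form works whether or not checkpointing is active.
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)

With a base class like this in place, every call site below collapses to a single, ordinary layer call.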

View File

@@ -24,7 +24,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, TypedDict, Union
import torch
@@ -38,6 +37,7 @@ from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -938,7 +938,7 @@ class BambaRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class BambaDecoderLayer(nn.Module):
class BambaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
super().__init__()
@@ -1154,30 +1154,17 @@ class BambaModel(BambaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **kwargs),
hidden_states,
layer_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
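
The `from functools import partial` import dropped above was only needed to bind `**kwargs` before handing the layer to the checkpointing function, because reentrant checkpointing rejects keyword arguments. A standalone sketch of the difference, with an illustrative `layer_fn` standing in for the decoder layer:

import torch
from functools import partial
from torch.utils.checkpoint import checkpoint

def layer_fn(hidden_states, attention_mask=None, scale=1.0):
    # Stand-in for a decoder layer's forward (names illustrative).
    out = hidden_states * scale
    if attention_mask is not None:
        out = out + attention_mask
    return out

x = torch.randn(2, 8, requires_grad=True)

# Old pattern (mirrors the removed branch): bind keyword args with partial,
# because reentrant checkpointing only accepts positional arguments.
out_old = checkpoint(partial(layer_fn, scale=2.0), x, None, use_reentrant=True)

# Non-reentrant checkpointing forwards kwargs itself, so no partial is needed
# and the call looks like an ordinary layer call.
out_new = checkpoint(layer_fn, x, attention_mask=None, scale=2.0, use_reentrant=False)
out_new.sum().backward()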

View File

@@ -19,7 +19,6 @@
# limitations under the License.
"""PyTorch Bamba model."""
from functools import partial
from typing import Optional, Tuple, TypedDict, Union
import torch
@@ -928,30 +927,17 @@ class BambaModel(BambaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **kwargs),
hidden_states,
layer_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@@ -25,6 +25,7 @@ from torch.nn.functional import normalize
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int
@@ -405,7 +406,7 @@ class BlipMLP(nn.Module):
return hidden_states
class BlipEncoderLayer(nn.Module):
class BlipEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BlipConfig):
super().__init__()
self.embed_dim = config.hidden_size
@@ -548,19 +549,12 @@ class BlipEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

View File

@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -317,7 +318,7 @@ class BlipTextOutput(nn.Module):
return hidden_states
class BlipTextLayer(nn.Module):
class BlipTextLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_num):
super().__init__()
self.config = config
@@ -421,27 +422,15 @@ class BlipTextEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -373,7 +374,7 @@ class Blip2MLP(nn.Module):
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2
class Blip2EncoderLayer(nn.Module):
class Blip2EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Blip2Config):
super().__init__()
self.embed_dim = config.hidden_size
@@ -527,19 +528,12 @@ class Blip2Encoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@@ -847,7 +841,7 @@ class Blip2QFormerOutput(nn.Module):
return hidden_states
class Blip2QFormerLayer(nn.Module):
class Blip2QFormerLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -988,31 +982,22 @@ class Blip2QFormerEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
hidden_states = layer_outputs[0]
if use_cache:
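
In the Q-Former hunk above, the `use_cache` warning and downgrade move out of a dedicated checkpointing branch into one flattened condition, and the layer is now called identically in both modes (the old checkpointed path did not forward `past_key_value`, `output_attentions`, or `query_length`). A quick standalone check, with hypothetical helper names, that the flattened guard behaves the same for every flag combination:

from itertools import product

def old_use_cache(checkpointing, training, use_cache):
    # Old shape: the downgrade is nested inside the checkpointing branch.
    if checkpointing and training:
        if use_cache:
            use_cache = False
    return use_cache

def new_use_cache(checkpointing, training, use_cache):
    # New shape: a single flattened condition, same downgrade.
    if checkpointing and training and use_cache:
        use_cache = False
    return use_cache

assert all(
    old_use_cache(*flags) == new_use_cache(*flags)
    for flags in product([False, True], repeat=3)
)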

View File

@@ -25,6 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -277,7 +278,7 @@ class InstructBlipMLP(nn.Module):
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
class InstructBlipEncoderLayer(nn.Module):
class InstructBlipEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: InstructBlipConfig):
super().__init__()
self.embed_dim = config.hidden_size
@@ -423,19 +424,12 @@ class InstructBlipEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@@ -744,7 +738,7 @@ class InstructBlipQFormerOutput(nn.Module):
return hidden_states
class InstructBlipQFormerLayer(nn.Module):
class InstructBlipQFormerLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -885,31 +879,22 @@ class InstructBlipQFormerEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@@ -29,6 +29,7 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -247,7 +248,7 @@ class InstructBlipVideoMLP(nn.Module):
return hidden_states
class InstructBlipVideoEncoderLayer(nn.Module):
class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: InstructBlipVideoConfig):
super().__init__()
self.embed_dim = config.hidden_size
@@ -352,19 +353,12 @@ class InstructBlipVideoEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@@ -606,7 +600,7 @@ class InstructBlipVideoQFormerOutput(nn.Module):
return hidden_states
class InstructBlipVideoQFormerLayer(nn.Module):
class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -746,31 +740,22 @@ class InstructBlipVideoQFormerEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
hidden_states = layer_outputs[0]
if use_cache:

View File

@@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ o
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, logging
@@ -894,7 +895,7 @@ class JambaSparseMoeBlock(nn.Module):
return final_hidden_states, router_logits
class JambaAttentionDecoderLayer(nn.Module):
class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: JambaConfig, layer_idx: int):
super().__init__()
num_experts = config.layers_num_experts[layer_idx]
@@ -976,7 +977,7 @@ class JambaAttentionDecoderLayer(nn.Module):
return outputs
class JambaMambaDecoderLayer(nn.Module):
class JambaMambaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: JambaConfig, layer_idx: int):
super().__init__()
num_experts = config.layers_num_experts[layer_idx]
@@ -1186,29 +1187,16 @@ class JambaModel(JambaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
layer_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
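
Because the Jamba decoder layers (like the Bamba and BLIP layers above) now subclass `GradientCheckpointingLayer`, nothing changes for callers: checkpointing is still enabled through the usual `PreTrainedModel` API. A hedged usage sketch with a placeholder checkpoint path:

from transformers import AutoModelForCausalLM

# Placeholder path; substitute whichever Jamba/Bamba checkpoint you actually use.
model = AutoModelForCausalLM.from_pretrained("path/to/jamba-checkpoint")

# Standard switch: flags every GradientCheckpointingLayer so its __call__
# routes through torch.utils.checkpoint during training; inference is unchanged.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()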