Use Gradient Checkpointing Layer in Jamba & Blip Related Models (#38310)

* Use gradient checkpointing class in blip classes

* Use gradient checkpointing class in jamba/bamba
Alex Brooks 2025-05-23 11:35:25 -06:00 committed by GitHub
parent 53fb245eb6
commit e64ed0304c
8 changed files with 128 additions and 229 deletions
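
For context, `GradientCheckpointingLayer` (from `transformers.modeling_layers`) moves the checkpointing decision into the layer itself: when the parent model enables gradient checkpointing during training, calling the layer runs its forward under activation checkpointing, so the per-model `if self.gradient_checkpointing and self.training:` branches removed below are no longer needed. A minimal sketch of the pattern, assuming non-reentrant `torch.utils.checkpoint` and a `gradient_checkpointing` flag set by the parent model; the class and attribute names here are illustrative, not the exact library internals:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointingLayerSketch(nn.Module):
    # Hypothetical stand-in for transformers' GradientCheckpointingLayer.
    gradient_checkpointing = False  # flipped on by the parent model during training

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Non-reentrant checkpointing accepts keyword arguments directly,
            # so callers no longer need functools.partial tricks.
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)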


@@ -24,7 +24,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 from typing import Callable, Optional, Tuple, TypedDict, Union
 import torch

@@ -38,6 +37,7 @@ from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel

@@ -938,7 +938,7 @@ class BambaRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

-class BambaDecoderLayer(nn.Module):
+class BambaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
         super().__init__()

@@ -1154,30 +1154,17 @@ class BambaModel(BambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **kwargs),
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )

             hidden_states = layer_outputs[0]
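
From the training side nothing changes: gradient checkpointing is still switched on through the model, and the converted layers handle the rest when called. A small usage sketch (the checkpoint id is a placeholder, not taken from this commit):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("<any-bamba-or-jamba-checkpoint>")  # placeholder id
model.train()
model.gradient_checkpointing_enable()  # existing PreTrainedModel API; the converted layers checkpoint themselves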


@@ -19,7 +19,6 @@
 # limitations under the License.
 """PyTorch Bamba model."""
-from functools import partial
 from typing import Optional, Tuple, TypedDict, Union
 import torch

@@ -928,30 +927,17 @@ class BambaModel(BambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **kwargs),
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )

             hidden_states = layer_outputs[0]
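
The dropped `functools.partial` wrapper was only there because `self._gradient_checkpointing_func` forwards positional arguments to the wrapped callable, so keyword arguments such as the extra `**kwargs` had to be pre-bound. A simplified before/after sketch of that calling pattern (an illustration, not the exact transformers helper):

from functools import partial

from torch.utils.checkpoint import checkpoint


def call_layer_old(layer, hidden_states, **kwargs):
    # Reentrant checkpointing only forwards positional args, hence the partial().
    return checkpoint(partial(layer.__call__, **kwargs), hidden_states, use_reentrant=True)


def call_layer_new(layer, hidden_states, **kwargs):
    # The layer now decides internally whether to checkpoint, so callers just call it.
    return layer(hidden_states, **kwargs)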


@@ -25,6 +25,7 @@ from torch.nn.functional import normalize
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging, torch_int

@@ -405,7 +406,7 @@ class BlipMLP(nn.Module):
         return hidden_states

-class BlipEncoderLayer(nn.Module):
+class BlipEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BlipConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -548,19 +549,12 @@ class BlipEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]


@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,

@@ -317,7 +318,7 @@ class BlipTextOutput(nn.Module):
         return hidden_states

-class BlipTextLayer(nn.Module):
+class BlipTextLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_num):
         super().__init__()
         self.config = config

@@ -421,27 +422,15 @@ class BlipTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,

@@ -373,7 +374,7 @@ class Blip2MLP(nn.Module):
 # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2
-class Blip2EncoderLayer(nn.Module):
+class Blip2EncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Blip2Config):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -527,19 +528,12 @@ class Blip2Encoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]

@@ -847,7 +841,7 @@ class Blip2QFormerOutput(nn.Module):
         return hidden_states

-class Blip2QFormerLayer(nn.Module):
+class Blip2QFormerLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward

@@ -988,31 +982,22 @@ class Blip2QFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -25,6 +25,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,

@@ -277,7 +278,7 @@ class InstructBlipMLP(nn.Module):
 # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
-class InstructBlipEncoderLayer(nn.Module):
+class InstructBlipEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InstructBlipConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -423,19 +424,12 @@ class InstructBlipEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]

@@ -744,7 +738,7 @@ class InstructBlipQFormerOutput(nn.Module):
         return hidden_states

-class InstructBlipQFormerLayer(nn.Module):
+class InstructBlipQFormerLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward

@@ -885,31 +879,22 @@ class InstructBlipQFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -29,6 +29,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,

@@ -247,7 +248,7 @@ class InstructBlipVideoMLP(nn.Module):
         return hidden_states

-class InstructBlipVideoEncoderLayer(nn.Module):
+class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InstructBlipVideoConfig):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -352,19 +353,12 @@ class InstructBlipVideoEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )

             hidden_states = layer_outputs[0]

@@ -606,7 +600,7 @@ class InstructBlipVideoQFormerOutput(nn.Module):
         return hidden_states

-class InstructBlipVideoQFormerLayer(nn.Module):
+class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward

@@ -746,31 +740,22 @@ class InstructBlipVideoQFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )

             hidden_states = layer_outputs[0]
             if use_cache:


@@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ o
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, can_return_tuple, logging

@@ -894,7 +895,7 @@ class JambaSparseMoeBlock(nn.Module):
         return final_hidden_states, router_logits

-class JambaAttentionDecoderLayer(nn.Module):
+class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JambaConfig, layer_idx: int):
         super().__init__()
         num_experts = config.layers_num_experts[layer_idx]

@@ -976,7 +977,7 @@ class JambaAttentionDecoderLayer(nn.Module):
         return outputs

-class JambaMambaDecoderLayer(nn.Module):
+class JambaMambaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JambaConfig, layer_idx: int):
         super().__init__()
         num_experts = config.layers_num_experts[layer_idx]

@@ -1186,29 +1187,16 @@ class JambaModel(JambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    output_router_logits,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    output_router_logits=output_router_logits,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                output_router_logits=output_router_logits,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )

             hidden_states = layer_outputs[0]