From e64ed0304c53798af2a3c4c5882473a0b3d28e37 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 23 May 2025 11:35:25 -0600 Subject: [PATCH] Use Gradient Checkpointing Layer in Jamba & Blip Related Models (#38310) * Use gradient checkpointing class in blip classes * Use gradient checkpointing class in jamba/bamba --- .../models/bamba/modeling_bamba.py | 39 ++++-------- .../models/bamba/modular_bamba.py | 36 ++++------- src/transformers/models/blip/modeling_blip.py | 22 +++---- .../models/blip/modeling_blip_text.py | 33 ++++------ .../models/blip_2/modeling_blip_2.py | 63 +++++++------------ .../instructblip/modeling_instructblip.py | 63 +++++++------------ .../modeling_instructblipvideo.py | 63 +++++++------------ .../models/jamba/modeling_jamba.py | 38 ++++------- 8 files changed, 128 insertions(+), 229 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 918782826dd..1b8e12d1c3b 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -24,7 +24,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, TypedDict, Union import torch @@ -38,6 +37,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -938,7 +938,7 @@ class BambaRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class BambaDecoderLayer(nn.Module): +class BambaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"): super().__init__() @@ -1154,30 +1154,17 @@ class BambaModel(BambaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **kwargs), - hidden_states, - layer_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=layer_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 7e0090b3945..9db52ebfbc5 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -19,7 +19,6 @@ # limitations under the License. 
"""PyTorch Bamba model.""" -from functools import partial from typing import Optional, Tuple, TypedDict, Union import torch @@ -928,30 +927,17 @@ class BambaModel(BambaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **kwargs), - hidden_states, - layer_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=layer_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 548a362ebfd..356f48eaf94 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -25,6 +25,7 @@ from torch.nn.functional import normalize from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -405,7 +406,7 @@ class BlipMLP(nn.Module): return hidden_states -class BlipEncoderLayer(nn.Module): +class BlipEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlipConfig): super().__init__() self.embed_dim = config.hidden_size @@ -548,19 +549,12 @@ class BlipEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index f26f269c7b9..ffbca32eb9d 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -317,7 +318,7 @@ class BlipTextOutput(nn.Module): return hidden_states -class BlipTextLayer(nn.Module): +class BlipTextLayer(GradientCheckpointingLayer): def __init__(self, config, layer_num): super().__init__() self.config = config @@ -421,27 +422,15 @@ class 
BlipTextEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 3ca38af6add..ea591bf730d 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -373,7 +374,7 @@ class Blip2MLP(nn.Module): # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2 -class Blip2EncoderLayer(nn.Module): +class Blip2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Blip2Config): super().__init__() self.embed_dim = config.hidden_size @@ -527,19 +528,12 @@ class Blip2Encoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -847,7 +841,7 @@ class Blip2QFormerOutput(nn.Module): return hidden_states -class Blip2QFormerLayer(nn.Module): +class Blip2QFormerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -988,31 +982,22 @@ class Blip2QFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) + use_cache = False + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index c90d22f012d..8018dbe76a9 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -277,7 +278,7 @@ class InstructBlipMLP(nn.Module): # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip -class InstructBlipEncoderLayer(nn.Module): +class InstructBlipEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: InstructBlipConfig): super().__init__() self.embed_dim = config.hidden_size @@ -423,19 +424,12 @@ class InstructBlipEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -744,7 +738,7 @@ class InstructBlipQFormerOutput(nn.Module): return hidden_states -class InstructBlipQFormerLayer(nn.Module): +class InstructBlipQFormerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -885,31 +879,22 @@ class InstructBlipQFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) + use_cache = False + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index b9f40deffef..cc18bbf90b6 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -29,6 +29,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -247,7 +248,7 @@ class InstructBlipVideoMLP(nn.Module): return hidden_states -class InstructBlipVideoEncoderLayer(nn.Module): +class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: InstructBlipVideoConfig): super().__init__() self.embed_dim = config.hidden_size @@ -352,19 +353,12 @@ class InstructBlipVideoEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -606,7 +600,7 @@ class InstructBlipVideoQFormerOutput(nn.Module): return hidden_states -class InstructBlipVideoQFormerLayer(nn.Module): +class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -746,31 +740,22 @@ class InstructBlipVideoQFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) + use_cache = False + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 90127ba70f9..d60190161ef 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ o from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, logging @@ -894,7 +895,7 @@ class JambaSparseMoeBlock(nn.Module): return final_hidden_states, router_logits -class JambaAttentionDecoderLayer(nn.Module): +class JambaAttentionDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: JambaConfig, layer_idx: int): super().__init__() num_experts = config.layers_num_experts[layer_idx] @@ -976,7 +977,7 @@ class JambaAttentionDecoderLayer(nn.Module): return outputs -class JambaMambaDecoderLayer(nn.Module): +class JambaMambaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: JambaConfig, layer_idx: int): super().__init__() num_experts = config.layers_num_experts[layer_idx] @@ -1186,29 +1187,16 @@ class JambaModel(JambaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - layer_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=layer_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0]
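
Every hunk above applies the same refactor: the per-layer module now inherits from GradientCheckpointingLayer, and the model forward drops its `if self.gradient_checkpointing and self.training:` branch and calls the layer unconditionally. Below is a minimal, self-contained sketch of how such a base class can route calls through torch.utils.checkpoint, for readers who want to see the mechanism in isolation. It is an illustrative assumption, not the actual transformers.modeling_layers.GradientCheckpointingLayer implementation: the names GradientCheckpointingLayerSketch and ToyDecoderLayer are hypothetical, and the use_reentrant=False choice is one reasonable option rather than the library's exact behavior.

    from functools import partial

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint


    class GradientCheckpointingLayerSketch(nn.Module):
        """Hypothetical stand-in for a gradient-checkpointing base layer."""

        # Flipped to True by the owning model, e.g. when gradient checkpointing
        # is enabled for training.
        gradient_checkpointing = False

        def __call__(self, *args, **kwargs):
            if self.gradient_checkpointing and self.training:
                # Bind keyword arguments up front (mirroring the removed
                # `partial(decoder_layer.__call__, **kwargs)` call sites in the
                # hunks above) and let checkpoint() recompute the forward pass
                # during backward instead of storing activations.
                return checkpoint(
                    partial(super().__call__, **kwargs), *args, use_reentrant=False
                )
            return super().__call__(*args, **kwargs)


    class ToyDecoderLayer(GradientCheckpointingLayerSketch):
        """Toy layer standing in for BambaDecoderLayer, BlipEncoderLayer, etc."""

        def __init__(self, hidden_size: int = 32):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, hidden_states, attention_mask=None):
            # Return a tuple, like the real decoder/encoder layers do.
            return (self.linear(hidden_states),)


    # The caller no longer branches on checkpointing state; it simply calls the
    # layer, exactly as the simplified model forward passes in this patch do.
    layer = ToyDecoderLayer()
    layer.gradient_checkpointing = True
    layer.train()
    hidden_states = torch.randn(2, 4, 32, requires_grad=True)
    outputs = layer(hidden_states, attention_mask=None)
    outputs[0].sum().backward()

The design payoff visible in the diff is that the branching and the partial/positional-argument plumbing live in one base class instead of being repeated in every model's forward loop, which is why each touched file shrinks by roughly a dozen lines per layer loop.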