Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-19 20:48:22 +06:00)
Use Gradient Checkpointing Layer in Jamba & Blip Related Models (#38310)

* Use gradient checkpointing class in blip classes
* Use gradient checkpointing class in jamba/bamba

parent 53fb245eb6
commit e64ed0304c
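
This commit drops the hand-rolled `if self.gradient_checkpointing and self.training:` branches from the Bamba, Jamba, BLIP, BLIP-2, InstructBLIP and InstructBlipVideo forward loops and instead makes each decoder/encoder/Q-Former layer subclass `GradientCheckpointingLayer` (imported from `...modeling_layers`), which applies activation checkpointing transparently when it is enabled. A minimal sketch of that pattern follows, using a simplified stand-in class rather than the actual `transformers` implementation:

# Minimal sketch of the pattern adopted here, NOT the real GradientCheckpointingLayer:
# the base class intercepts __call__ and reroutes the forward pass through
# torch.utils.checkpoint when checkpointing is enabled and the module is training,
# so call sites can always invoke the layer directly.
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointingLayerSketch(nn.Module):  # hypothetical stand-in
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # checkpoint() is given positional args only here, so keyword args
            # are pre-bound with functools.partial before wrapping the call.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)


class ToyDecoderLayer(CheckpointingLayerSketch):
    def __init__(self, hidden_size=16):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        return (self.linear(hidden_states),)


layer = ToyDecoderLayer()
layer.gradient_checkpointing = True  # in transformers this is toggled via gradient_checkpointing_enable()
hidden = torch.randn(2, 4, 16, requires_grad=True)
out = layer(hidden, attention_mask=None)[0]  # same call site with or without checkpointing
out.sum().backward()
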
@@ -24,7 +24,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import partial
 from typing import Callable, Optional, Tuple, TypedDict, Union
 
 import torch
@@ -38,6 +37,7 @@ from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -938,7 +938,7 @@ class BambaRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-class BambaDecoderLayer(nn.Module):
+class BambaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
         super().__init__()
 
@@ -1154,30 +1154,17 @@ class BambaModel(BambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **kwargs),
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
 
             hidden_states = layer_outputs[0]
 
@@ -19,7 +19,6 @@
 # limitations under the License.
 """PyTorch Bamba model."""
 
-from functools import partial
 from typing import Optional, Tuple, TypedDict, Union
 
 import torch
@@ -928,30 +927,17 @@ class BambaModel(BambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **kwargs),
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
 
             hidden_states = layer_outputs[0]
 
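
Note on the dropped `from functools import partial` in the two Bamba hunks above: the old call sites wrapped `decoder_layer.__call__` in `partial(..., **kwargs)` because the checkpointing helper hands arguments back to the wrapped callable positionally, so extra keyword arguments had to be frozen into the callable beforehand. With the layer base class handling checkpointing, the loop simply passes keywords directly. A toy illustration of the old pattern (illustrative function, not library code):

# Toy illustration of why the removed code bound keyword args with functools.partial
# before handing the callable to torch.utils.checkpoint.checkpoint.
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def layer_fn(hidden_states, attention_mask=None, scale=1.0):
    out = hidden_states * scale
    return out if attention_mask is None else out * attention_mask


x = torch.randn(2, 4, requires_grad=True)
# Old pattern: freeze keyword arguments first, checkpoint the rest positionally.
y = checkpoint(partial(layer_fn, scale=2.0), x, None, use_reentrant=False)
y.sum().backward()
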
@@ -25,6 +25,7 @@ from torch.nn.functional import normalize
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging, torch_int
@@ -405,7 +406,7 @@ class BlipMLP(nn.Module):
         return hidden_states
 
 
-class BlipEncoderLayer(nn.Module):
+class BlipEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BlipConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -548,19 +549,12 @@ class BlipEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
 
@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -317,7 +318,7 @@ class BlipTextOutput(nn.Module):
         return hidden_states
 
 
-class BlipTextLayer(nn.Module):
+class BlipTextLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_num):
         super().__init__()
         self.config = config
@@ -421,27 +422,15 @@ class BlipTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -373,7 +374,7 @@ class Blip2MLP(nn.Module):
 
 
 # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2
-class Blip2EncoderLayer(nn.Module):
+class Blip2EncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Blip2Config):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -527,19 +528,12 @@ class Blip2Encoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
 
@@ -847,7 +841,7 @@ class Blip2QFormerOutput(nn.Module):
         return hidden_states
 
 
-class Blip2QFormerLayer(nn.Module):
+class Blip2QFormerLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -988,31 +982,22 @@ class Blip2QFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -25,6 +25,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -277,7 +278,7 @@ class InstructBlipMLP(nn.Module):
 
 
 # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
-class InstructBlipEncoderLayer(nn.Module):
+class InstructBlipEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InstructBlipConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -423,19 +424,12 @@ class InstructBlipEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
 
@@ -744,7 +738,7 @@ class InstructBlipQFormerOutput(nn.Module):
         return hidden_states
 
 
-class InstructBlipQFormerLayer(nn.Module):
+class InstructBlipQFormerLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -885,31 +879,22 @@ class InstructBlipQFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -29,6 +29,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -247,7 +248,7 @@ class InstructBlipVideoMLP(nn.Module):
         return hidden_states
 
 
-class InstructBlipVideoEncoderLayer(nn.Module):
+class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InstructBlipVideoConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -352,19 +353,12 @@ class InstructBlipVideoEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
 
@@ -606,7 +600,7 @@ class InstructBlipVideoQFormerOutput(nn.Module):
         return hidden_states
 
 
-class InstructBlipVideoQFormerLayer(nn.Module):
+class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -746,31 +740,22 @@ class InstructBlipVideoQFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ o
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, can_return_tuple, logging
@@ -894,7 +895,7 @@ class JambaSparseMoeBlock(nn.Module):
         return final_hidden_states, router_logits
 
 
-class JambaAttentionDecoderLayer(nn.Module):
+class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JambaConfig, layer_idx: int):
         super().__init__()
         num_experts = config.layers_num_experts[layer_idx]
@@ -976,7 +977,7 @@ class JambaAttentionDecoderLayer(nn.Module):
         return outputs
 
 
-class JambaMambaDecoderLayer(nn.Module):
+class JambaMambaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JambaConfig, layer_idx: int):
         super().__init__()
         num_experts = config.layers_num_experts[layer_idx]
@@ -1186,29 +1187,16 @@ class JambaModel(JambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    output_router_logits,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    output_router_logits=output_router_logits,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                output_router_logits=output_router_logits,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
 
             hidden_states = layer_outputs[0]
 
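
The refactor does not change how callers enable gradient checkpointing on the affected models; the layers now take care of the branching themselves, and checkpointing still only applies while the model is in training mode. A sketch of typical usage (the checkpoint name is only an illustration, any of the models touched by this commit works the same way):

# Illustrative usage; model id is an example, not prescribed by this commit.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini")  # example checkpoint
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()  # checkpointing only kicks in while the model is training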