Use Gradient Checkpointing Layer in Jamba & Blip Related Models (#38310)
* Use gradient checkpointing class in blip classes
* Use gradient checkpointing class in jamba/bamba
This commit is contained in:
parent 53fb245eb6
commit e64ed0304c
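This change removes the hand-rolled `if self.gradient_checkpointing and self.training:` branches from each model's forward loop and instead makes the decoder/encoder layer classes inherit from `GradientCheckpointingLayer`, so checkpointing dispatch lives once in a shared base class and every call site collapses to a plain layer call. As a rough illustration of the pattern only (a minimal sketch under assumed names, not the actual `GradientCheckpointingLayer` implementation in `modeling_layers.py`):

```python
# Minimal sketch of a checkpointing-aware layer base class. The class name and
# the use_reentrant=False choice are illustrative assumptions, not the real
# transformers implementation.
from functools import partial

from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointingLayerSketch(nn.Module):
    """Routes __call__ through activation checkpointing when enabled during training."""

    gradient_checkpointing = False  # toggled by the owning model

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Bind kwargs up front so checkpoint() only forwards positional args,
            # then re-enter the normal nn.Module.__call__ path under checkpointing.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)
```

With a base class like this, a layer only needs to change its parent class; callers no longer need the `self._gradient_checkpointing_func(...)` / `else:` fork that the hunks below remove.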
@@ -24,7 +24,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import partial
 from typing import Callable, Optional, Tuple, TypedDict, Union
 
 import torch
@@ -38,6 +37,7 @@ from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -938,7 +938,7 @@ class BambaRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-class BambaDecoderLayer(nn.Module):
+class BambaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
         super().__init__()
 
@@ -1154,30 +1154,17 @@ class BambaModel(BambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **kwargs),
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
 
             hidden_states = layer_outputs[0]
@@ -19,7 +19,6 @@
 # limitations under the License.
 """PyTorch Bamba model."""
 
-from functools import partial
 from typing import Optional, Tuple, TypedDict, Union
 
 import torch
@@ -928,30 +927,17 @@ class BambaModel(BambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **kwargs),
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
 
             hidden_states = layer_outputs[0]
@@ -25,6 +25,7 @@ from torch.nn.functional import normalize
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging, torch_int
@@ -405,7 +406,7 @@ class BlipMLP(nn.Module):
         return hidden_states
 
 
-class BlipEncoderLayer(nn.Module):
+class BlipEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: BlipConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -548,19 +549,12 @@ class BlipEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -317,7 +318,7 @@ class BlipTextOutput(nn.Module):
         return hidden_states
 
 
-class BlipTextLayer(nn.Module):
+class BlipTextLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_num):
         super().__init__()
         self.config = config
@@ -421,27 +422,15 @@ class BlipTextEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -373,7 +374,7 @@ class Blip2MLP(nn.Module):
 
 
 # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2
-class Blip2EncoderLayer(nn.Module):
+class Blip2EncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Blip2Config):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -527,19 +528,12 @@ class Blip2Encoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
@@ -847,7 +841,7 @@ class Blip2QFormerOutput(nn.Module):
         return hidden_states
 
 
-class Blip2QFormerLayer(nn.Module):
+class Blip2QFormerLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -988,31 +982,22 @@ class Blip2QFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -25,6 +25,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -277,7 +278,7 @@ class InstructBlipMLP(nn.Module):
 
 
 # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
-class InstructBlipEncoderLayer(nn.Module):
+class InstructBlipEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InstructBlipConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -423,19 +424,12 @@ class InstructBlipEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
@@ -744,7 +738,7 @@ class InstructBlipQFormerOutput(nn.Module):
         return hidden_states
 
 
-class InstructBlipQFormerLayer(nn.Module):
+class InstructBlipQFormerLayer(GradientCheckpointingLayer):
     def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -885,31 +879,22 @@ class InstructBlipQFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -29,6 +29,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -247,7 +248,7 @@ class InstructBlipVideoMLP(nn.Module):
         return hidden_states
 
 
-class InstructBlipVideoEncoderLayer(nn.Module):
+class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: InstructBlipVideoConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -352,19 +353,12 @@ class InstructBlipVideoEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = encoder_layer(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions=output_attentions,
-                )
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                output_attentions=output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
@@ -606,7 +600,7 @@ class InstructBlipVideoQFormerOutput(nn.Module):
         return hidden_states
 
 
-class InstructBlipVideoQFormerLayer(nn.Module):
+class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -746,31 +740,22 @@ class InstructBlipVideoQFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-                layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__,
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
-                    query_length,
-                )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                query_length,
+            )
 
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ o
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, can_return_tuple, logging
@@ -894,7 +895,7 @@ class JambaSparseMoeBlock(nn.Module):
         return final_hidden_states, router_logits
 
 
-class JambaAttentionDecoderLayer(nn.Module):
+class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JambaConfig, layer_idx: int):
         super().__init__()
         num_experts = config.layers_num_experts[layer_idx]
@@ -976,7 +977,7 @@ class JambaAttentionDecoderLayer(nn.Module):
         return outputs
 
 
-class JambaMambaDecoderLayer(nn.Module):
+class JambaMambaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JambaConfig, layer_idx: int):
         super().__init__()
         num_experts = config.layers_num_experts[layer_idx]
@@ -1186,29 +1187,16 @@ class JambaModel(JambaPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    layer_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    output_router_logits,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=layer_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    output_router_logits=output_router_logits,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                output_router_logits=output_router_logits,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
 
             hidden_states = layer_outputs[0]
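From the user's side nothing changes: activation checkpointing is still enabled through the usual `PreTrainedModel` API, and the refactored layers pick it up automatically. A quick usage sketch (the checkpoint id is an illustrative placeholder):

```python
# Usage sketch: enabling activation checkpointing on a model whose layers now
# subclass GradientCheckpointingLayer. Model id is a placeholder example.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
model.gradient_checkpointing_enable()  # layers checkpoint themselves during training
model.train()
```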