update modeling

ducviet00 2025-06-28 21:27:31 +07:00
parent af63c6869c
commit e557dc695d


@@ -37,6 +37,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -760,7 +761,7 @@ class Florence2LanguageAttention(nn.Module):
return attn_output, attn_weights, past_key_value
-class Florence2LanguageEncoderLayer(nn.Module):
+class Florence2LanguageEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Florence2LanguageConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@@ -831,7 +832,7 @@ class Florence2LanguageEncoderLayer(nn.Module):
return outputs
-class Florence2LanguageDecoderLayer(nn.Module):
+class Florence2LanguageDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Florence2LanguageConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@@ -1351,21 +1352,12 @@ class Florence2LanguageEncoder(Florence2LanguagePreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
-if self.gradient_checkpointing and self.training:
-    layer_outputs = self._gradient_checkpointing_func(
-        encoder_layer.__call__,
-        hidden_states,
-        attention_mask,
-        (head_mask[idx] if head_mask is not None else None),
-        output_attentions,
-    )
-else:
-    layer_outputs = encoder_layer(
-        hidden_states,
-        attention_mask,
-        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-        output_attentions=output_attentions,
-    )
+layer_outputs = encoder_layer(
+    hidden_states,
+    attention_mask,
+    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+    output_attentions=output_attentions,
+)
hidden_states = layer_outputs[0]
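
Why the encoder loop can now call encoder_layer(...) unconditionally: the checkpointing branch moves into the layer's base class, so the call site no longer checks self.gradient_checkpointing itself. Below is a minimal, hedged sketch of that pattern; CheckpointedLayer and ToyEncoderLayer are illustrative names, not the actual GradientCheckpointingLayer in modeling_layers, and the sketch assumes the standard torch.utils.checkpoint.checkpoint API.

# Hedged sketch only: fold the gradient-checkpointing branch into __call__,
# so encoder/decoder loops stay branch-free.
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedLayer(nn.Module):  # hypothetical stand-in for the base class
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Keyword arguments are bound up front; only positional arguments
            # are visible to torch.utils.checkpoint as inputs.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)


class ToyEncoderLayer(CheckpointedLayer):  # hypothetical example layer
    def __init__(self, dim: int):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None):
        return torch.relu(self.fc(hidden_states))


layer = ToyEncoderLayer(8).train()
layer.gradient_checkpointing = True
hidden = torch.randn(2, 8, requires_grad=True)
out = layer(hidden, attention_mask=None)  # checkpointed call, no branch at the call site
out.sum().backward()

Centralizing the branch in one base class keeps training and inference call sites identical and removes the duplicated _gradient_checkpointing_func plumbing from every encoder and decoder loop.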
@@ -1615,35 +1607,18 @@ class Florence2LanguageDecoder(Florence2LanguagePreTrainedModel):
if dropout_probability < self.layerdrop:
continue
-if self.gradient_checkpointing and self.training:
-    layer_outputs = self._gradient_checkpointing_func(
-        decoder_layer.__call__,
-        hidden_states,
-        attention_mask,
-        encoder_hidden_states,
-        encoder_attention_mask,
-        head_mask[idx] if head_mask is not None else None,
-        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-        None,
-        output_attentions,
-        use_cache,
-        cache_position,
-    )
-else:
-    layer_outputs = decoder_layer(
-        hidden_states,
-        attention_mask=attention_mask,
-        encoder_hidden_states=encoder_hidden_states,
-        encoder_attention_mask=encoder_attention_mask,
-        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-        cross_attn_layer_head_mask=(
-            cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-        ),
-        past_key_value=past_key_values,
-        output_attentions=output_attentions,
-        use_cache=use_cache,
-        cache_position=cache_position,
-    )
+layer_outputs = decoder_layer(
+    hidden_states,
+    attention_mask,
+    encoder_hidden_states,  # as a positional argument for gradient checkpointing
+    encoder_attention_mask=encoder_attention_mask,
+    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+    cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+    past_key_value=past_key_values,
+    output_attentions=output_attentions,
+    use_cache=use_cache,
+    cache_position=cache_position,
+)
hidden_states = layer_outputs[0]
if use_cache:
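
The inline comment on the decoder call ("as a positional argument for gradient checkpointing") points at a constraint of torch's checkpoint API: with the pattern sketched above, keyword arguments are bound before checkpointing, and the reentrant implementation does not accept keyword arguments at all, so a tensor that must stay visible to the checkpoint as an input, here encoder_hidden_states, is forwarded positionally. A small, hedged illustration follows; the cross_attend function and tensor names are made up for the example.

# Hedged illustration: the reentrant checkpoint only takes positional inputs,
# so tensors that need to be registered as checkpoint inputs (e.g. the encoder
# output feeding cross-attention) are passed positionally rather than as kwargs.
import torch
from torch.utils.checkpoint import checkpoint


def cross_attend(hidden_states, encoder_hidden_states):
    # stand-in for a decoder layer that mixes in encoder output
    return hidden_states + encoder_hidden_states.mean()


hidden = torch.randn(2, 4, requires_grad=True)
encoder_out = torch.randn(2, 4, requires_grad=True)

# Positional: accepted, and gradients reach both tensors.
checkpoint(cross_attend, hidden, encoder_out, use_reentrant=True).sum().backward()
print(hidden.grad is not None, encoder_out.grad is not None)  # True True

# Keyword: the reentrant implementation rejects keyword arguments outright.
try:
    checkpoint(cross_attend, hidden, encoder_hidden_states=encoder_out, use_reentrant=True)
except (TypeError, ValueError) as err:
    print(f"rejected: {err}")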