update modeling

commit e557dc695d (parent af63c6869c)
@@ -37,6 +37,7 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask_for_sdpa,
 )
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -760,7 +761,7 @@ class Florence2LanguageAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-class Florence2LanguageEncoderLayer(nn.Module):
+class Florence2LanguageEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Florence2LanguageConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -831,7 +832,7 @@ class Florence2LanguageEncoderLayer(nn.Module):
         return outputs


-class Florence2LanguageDecoderLayer(nn.Module):
+class Florence2LanguageDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Florence2LanguageConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.embed_dim = config.d_model
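
The two hunks above move Florence2LanguageEncoderLayer and Florence2LanguageDecoderLayer from plain nn.Module onto the GradientCheckpointingLayer base class imported in the first hunk. As a rough mental model, that base class intercepts __call__ and, while training with checkpointing enabled, routes the forward pass through the module's gradient-checkpointing function, so the encoder and decoder loops further down no longer need an explicit `if self.gradient_checkpointing and self.training` branch. The sketch below illustrates only this pattern, assuming the standard torch.utils.checkpoint machinery; the names CheckpointingLayerSketch and TinyLayer are hypothetical and this is not the actual transformers implementation.

# Minimal sketch of the GradientCheckpointingLayer idea (hypothetical names,
# not the transformers code itself).
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointingLayerSketch(nn.Module):
    """Base class that makes gradient checkpointing transparent to callers."""

    gradient_checkpointing = False  # toggled externally when checkpointing is enabled on the model

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Bind keyword arguments up front; the positional tensors are what
            # the checkpoint function replays through the recomputed forward.
            fn = partial(super().__call__, **kwargs)
            # use_reentrant=False supports keyword-bound closures and non-tensor args.
            return checkpoint(fn, *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)


class TinyLayer(CheckpointingLayerSketch):
    """Toy stand-in for Florence2LanguageEncoderLayer."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None):
        del attention_mask  # unused in this toy example
        return torch.relu(self.proj(hidden_states))


layer = TinyLayer(8).train()
layer.gradient_checkpointing = True
x = torch.randn(2, 4, 8, requires_grad=True)
out = layer(x, attention_mask=None)  # no checkpointing branch at the call site
out.sum().backward()
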
@@ -1351,21 +1352,12 @@ class Florence2LanguageEncoder(Florence2LanguagePreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                if self.gradient_checkpointing and self.training:
-                    layer_outputs = self._gradient_checkpointing_func(
-                        encoder_layer.__call__,
-                        hidden_states,
-                        attention_mask,
-                        (head_mask[idx] if head_mask is not None else None),
-                        output_attentions,
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )

             hidden_states = layer_outputs[0]

@@ -1615,35 +1607,18 @@ class Florence2LanguageDecoder(Florence2LanguagePreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
             hidden_states = layer_outputs[0]

             if use_cache:
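
With both layer classes inheriting the checkpointing behaviour, the encoder and decoder loops above collapse to a single direct call per layer; whether activations are recomputed during backward is now decided inside the layer's own __call__ rather than at every call site. The inline comment in the decoder hunk points at the one call-site adjustment this requires: hidden_states, attention_mask, and encoder_hidden_states are passed positionally so they travel down the positional-argument path that the checkpointing wrapper replays, while the remaining arguments stay as keywords. Callers still enable the feature the usual way, via the model-level gradient_checkpointing_enable().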