Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Commit 0b119ffb1f, parent 3ac6c52f34. Commit message: "quel enfer" ("what a nightmare").
@@ -24,6 +24,8 @@ from typing import Callable, Optional, Union
import torch
import torch.nn as nn

from transformers.utils.generic import check_model_inputs

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
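The new `check_model_inputs` import is what lets the hunks below drop the explicit `output_attentions` / `output_hidden_states` parameters from several `forward` signatures: the decorator takes over resolving those flags. Purely as a hypothetical illustration of that pattern (this is not the real `transformers.utils.generic.check_model_inputs`), such a wrapper could fall back to the model config when a flag is not passed:

import functools

def check_model_inputs_sketch(forward):
    # Hypothetical stand-in for the decorator imported above, for illustration only.
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        # Resolve the output flags from the config when the caller does not pass them.
        for flag in ("output_attentions", "output_hidden_states"):
            if kwargs.get(flag) is None:
                kwargs[flag] = getattr(self.config, flag, False)
        return forward(self, *args, **kwargs)
    return wrapper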
@@ -342,13 +344,7 @@ class T5GemmaCrossAttention(nn.Module):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
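The hunk above is the core simplification in this file: the SDPA-specific `output_attentions` warning and eager fallback are removed, and any non-eager implementation is resolved directly from the attention registry. A minimal sketch of that dispatch pattern, with an illustrative registry rather than the real `ALL_ATTENTION_FUNCTIONS`:

import torch.nn.functional as F

def eager_attention(query, key, value, mask=None):
    # Plain softmax(Q K^T / sqrt(d)) V; the reference every backend should match.
    scores = query @ key.transpose(-2, -1) / key.shape[-1] ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))  # mask: True = may attend
    return F.softmax(scores, dim=-1) @ value

# Illustrative registry; the real ALL_ATTENTION_FUNCTIONS covers more backends.
ATTENTION_REGISTRY = {
    "sdpa": lambda q, k, v, mask=None: F.scaled_dot_product_attention(q, k, v, attn_mask=mask),
}

def select_attention(impl_name: str):
    # Eager is the default; any other configured backend is looked up in the registry,
    # mirroring the simplified selection in the hunk above.
    return eager_attention if impl_name == "eager" else ATTENTION_REGISTRY[impl_name]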
@@ -720,7 +716,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -793,15 +789,13 @@ class T5GemmaDecoder(T5GemmaEncoder):

    def __init__(self, config):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -841,7 +835,6 @@ class T5GemmaDecoder(T5GemmaEncoder):
            attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)

        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
@@ -849,15 +842,12 @@ class T5GemmaDecoder(T5GemmaEncoder):
                "cache_position": cache_position,
                "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
            }
            # Create the masks
            self_attn_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # Attention masks: Cross attention
        if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": encoder_hidden_states,
@@ -872,15 +862,9 @@ class T5GemmaDecoder(T5GemmaEncoder):
                ),
            }

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        hidden_states = self.dropout(hidden_states)
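The normalizer is created in `hidden_states.dtype` on purpose: sqrt(3072) ≈ 55.4256 is not exactly representable at low precision, and how it rounds depends on the dtype, so matching the activations' dtype keeps the scaling consistent with the original Gemma numerics. A quick check of the rounding:

import torch

exact = 3072 ** 0.5  # 55.42562584220407
print(torch.tensor(exact, dtype=torch.float32).item())   # ~55.4256
print(torch.tensor(exact, dtype=torch.float16).item())   # 55.4375
print(torch.tensor(exact, dtype=torch.bfloat16).item())  # 55.5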
@@ -946,8 +930,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Seq2SeqModelOutput:
@@ -962,8 +944,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                **kwargs,
            )

@@ -978,8 +958,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
@@ -1021,8 +999,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        encoder_outputs = self.encoder(
@@ -1030,8 +1006,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        return encoder_outputs
@@ -1074,23 +1048,18 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
    @auto_docstring
    def forward(
        self,
        # encoder
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # decoder
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        # others
        encoder_outputs: Optional[BaseModelOutput] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
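`logits_to_keep` follows the usual transformers convention for limiting how much of the sequence is projected to vocabulary logits: an int keeps only that many trailing positions (with 0 meaning keep everything), while a tensor is treated as explicit indices. A small sketch of that slicing convention (the helper name is mine, not part of the model):

import torch

def slice_for_logits(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden_size)
    if isinstance(logits_to_keep, int):
        keep = slice(None) if logits_to_keep == 0 else slice(-logits_to_keep, None)
        return hidden_states[:, keep, :]
    # A tensor selects explicit positions to keep.
    return hidden_states[:, logits_to_keep, :]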
@@ -1130,8 +1099,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **loss_kwargs,
        )
@@ -1202,21 +1169,16 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
    @auto_docstring
    def forward(
        self,
        # encoder
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # decoder
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        # others
        encoder_outputs: Optional[BaseModelOutput] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1254,8 +1216,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.decoder_hidden_states
@@ -1266,8 +1226,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.hidden_states
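The two branches above read their features from different output types: the encoder-decoder path gets `last_hidden_state` plus `decoder_hidden_states` from a `Seq2SeqModelOutput`, while the encoder-only path gets `last_hidden_state` plus `hidden_states` from a `BaseModelOutput`. A compact sketch of that selection (assuming an `is_encoder_decoder` flag like the one used by the surrounding class):

def select_classification_features(is_encoder_decoder: bool, outputs):
    # Pick the states fed to the classification head, depending on the backbone type.
    last_hidden_state = outputs.last_hidden_state
    if is_encoder_decoder:
        hidden_states = outputs.decoder_hidden_states  # Seq2SeqModelOutput
    else:
        hidden_states = outputs.hidden_states  # BaseModelOutput
    return last_hidden_state, hidden_states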
@@ -1350,21 +1308,16 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
    @auto_docstring
    def forward(
        self,
        # encoder
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # decoder
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        # others
        encoder_outputs: Optional[BaseModelOutput] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1402,8 +1355,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.decoder_hidden_states
@@ -1414,8 +1365,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.hidden_states

@@ -18,6 +18,8 @@ from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn

from transformers.utils.generic import check_model_inputs

from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
@@ -290,13 +292,7 @@ class T5GemmaCrossAttention(Gemma2Attention):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
@@ -582,7 +578,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -655,15 +651,13 @@ class T5GemmaDecoder(T5GemmaEncoder):

    def __init__(self, config):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -703,7 +697,6 @@ class T5GemmaDecoder(T5GemmaEncoder):
            attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)

        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
@@ -711,15 +704,12 @@ class T5GemmaDecoder(T5GemmaEncoder):
                "cache_position": cache_position,
                "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
            }
            # Create the masks
            self_attn_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # Attention masks: Cross attention
        if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": encoder_hidden_states,
@@ -734,15 +724,9 @@ class T5GemmaDecoder(T5GemmaEncoder):
                ),
            }

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        hidden_states = self.dropout(hidden_states)
@@ -808,8 +792,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Seq2SeqModelOutput:
@@ -824,8 +806,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                **kwargs,
            )

@@ -840,8 +820,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
@@ -883,8 +861,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        encoder_outputs = self.encoder(
@@ -892,8 +868,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        return encoder_outputs
@@ -936,23 +910,18 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
    @auto_docstring
    def forward(
        self,
        # encoder
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # decoder
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        # others
        encoder_outputs: Optional[BaseModelOutput] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
@@ -992,8 +961,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **loss_kwargs,
        )
@@ -1064,21 +1031,16 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
    @auto_docstring
    def forward(
        self,
        # encoder
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # decoder
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        # others
        encoder_outputs: Optional[BaseModelOutput] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1116,8 +1078,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.decoder_hidden_states
@@ -1128,8 +1088,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.hidden_states
@@ -1212,21 +1170,16 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
    @auto_docstring
    def forward(
        self,
        # encoder
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # decoder
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        # others
        encoder_outputs: Optional[BaseModelOutput] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1264,8 +1217,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.decoder_hidden_states
@@ -1276,8 +1227,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.hidden_states