from typing import Optional, Union

import torch

from transformers.models.bert.modeling_bert import BertModel

from ...modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions


class DummyBertModel(BertModel):
    """BERT model variant that overrides ``forward`` to delegate to the parent.

    NOTE(review): the override accepts the full :class:`BertModel` argument
    surface for signature compatibility, but only ``input_ids`` is actually
    forwarded to ``super().forward`` — ``attention_mask``, ``past_key_values``
    and every other argument are silently ignored. This looks deliberate for a
    dummy/fixture model; confirm against callers before relying on the other
    parameters having any effect.
    """

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Run the parent BERT forward pass using ``input_ids`` only.

        Returns whatever :meth:`BertModel.forward` returns for the given
        ``input_ids`` (a tuple or a
        ``BaseModelOutputWithPoolingAndCrossAttentions``).
        """
        # Bind the parent implementation once, then invoke it with the sole
        # argument this dummy model cares about.
        parent_forward = super().forward
        return parent_forward(input_ids)