mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-18 12:08:22 +06:00
albert
This commit is contained in:
parent
11de15bda4
commit
dd7aeca424
@ -14,17 +14,16 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch ALBERT model."""
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
@ -34,17 +33,20 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import (
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
is_torch_greater_or_equal_than_2_2,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging
|
||||
from .configuration_albert import AlbertConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -199,14 +201,12 @@ class AlbertEmbeddings(nn.Module):
|
||||
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
@ -216,7 +216,7 @@ class AlbertEmbeddings(nn.Module):
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
position_ids = self.position_ids[:, : seq_length]
|
||||
|
||||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
||||
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
||||
@ -242,6 +242,64 @@ class AlbertEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3))
|
||||
|
||||
# Relative positional embeddings
|
||||
if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query.shape[2], key.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
|
||||
|
||||
if module.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
|
||||
attn_weights = attn_weights + relative_position_scores
|
||||
elif module.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
|
||||
attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
# Scaling is shifted in case of embeddings being relative
|
||||
attn_weights = attn_weights * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AlbertAttention(nn.Module):
|
||||
def __init__(self, config: AlbertConfig):
|
||||
super().__init__()
|
||||
@ -250,19 +308,22 @@ class AlbertAttention(nn.Module):
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads}"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.attention_head_size = config.hidden_size // config.num_attention_heads
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -271,11 +332,7 @@ class AlbertAttention(nn.Module):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.is_causal = False
|
||||
|
||||
def prune_heads(self, heads: list[int]) -> None:
|
||||
if len(heads) == 0:
|
||||
@ -300,118 +357,49 @@ class AlbertAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# determine input shapes
|
||||
bsz, tgt_len = hidden_states.shape[:-1]
|
||||
src_len = tgt_len
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
|
||||
kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
# get all proj
|
||||
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError(
|
||||
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
|
||||
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
|
||||
)
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
seq_length = hidden_states.size()[1]
|
||||
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(2, 1).flatten(2)
|
||||
|
||||
projected_context_layer = self.dense(context_layer)
|
||||
projected_context_layer_dropout = self.output_dropout(projected_context_layer)
|
||||
layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
|
||||
return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
|
||||
|
||||
|
||||
class AlbertSdpaAttention(AlbertAttention):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
|
||||
|
||||
def forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
||||
if self.position_embedding_type != "absolute" or output_attentions:
|
||||
logger.warning(
|
||||
"AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
|
||||
"the eager attention implementation, but specifying the eager implementation will be required from "
|
||||
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
|
||||
'`attn_implementation="eager"` when loading the model.'
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout.p,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
# only for relevant for non-absolute positional embeddings
|
||||
use_cache=False,
|
||||
**kwargs,
|
||||
)
|
||||
return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
|
||||
batch_size, seq_len, _ = hidden_states.size()
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attn_output = self.dense(attn_output)
|
||||
attn_output = self.output_dropout(attn_output)
|
||||
attn_output = self.LayerNorm(hidden_states + attn_output)
|
||||
|
||||
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
||||
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
|
||||
query_layer = query_layer.contiguous()
|
||||
key_layer = key_layer.contiguous()
|
||||
value_layer = value_layer.contiguous()
|
||||
|
||||
attention_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_layer,
|
||||
key=key_layer,
|
||||
value=value_layer,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
attention_output = attention_output.transpose(1, 2)
|
||||
attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
|
||||
|
||||
projected_context_layer = self.dense(attention_output)
|
||||
projected_context_layer_dropout = self.output_dropout(projected_context_layer)
|
||||
layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
|
||||
return (layernormed_context_layer,)
|
||||
|
||||
|
||||
ALBERT_ATTENTION_CLASSES = {
|
||||
"eager": AlbertAttention,
|
||||
"sdpa": AlbertSdpaAttention,
|
||||
}
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AlbertLayer(nn.Module):
|
||||
@ -422,7 +410,7 @@ class AlbertLayer(nn.Module):
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = AlbertAttention(config)
|
||||
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
@ -433,10 +421,8 @@ class AlbertLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
|
||||
attention_output = self.attention(hidden_states, attention_mask, head_mask)
|
||||
|
||||
ffn_output = apply_chunking_to_forward(
|
||||
self.ff_chunk,
|
||||
@ -473,7 +459,7 @@ class AlbertLayerGroup(nn.Module):
|
||||
layer_attentions = ()
|
||||
|
||||
for layer_index, albert_layer in enumerate(self.albert_layers):
|
||||
layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
|
||||
layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
|
||||
hidden_states = layer_output[0]
|
||||
|
||||
if output_attentions:
|
||||
@ -548,7 +534,9 @@ class AlbertPreTrainedModel(PreTrainedModel):
|
||||
config_class = AlbertConfig
|
||||
load_tf_weights = load_tf_weights_in_albert
|
||||
base_model_prefix = "albert"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
@ -691,27 +679,16 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
use_sdpa_attention_mask = (
|
||||
self.attn_implementation == "sdpa"
|
||||
and self.position_embedding_type == "absolute"
|
||||
and head_mask is None
|
||||
and not output_attentions
|
||||
attention_mask = self._update_full_mask(
|
||||
attention_mask,
|
||||
embedding_output
|
||||
)
|
||||
|
||||
if use_sdpa_attention_mask:
|
||||
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
attention_mask, embedding_output.dtype, tgt_len=seq_length
|
||||
)
|
||||
else:
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@ -732,6 +709,29 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, None],
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
if attention_mask is not None:
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if 0 in attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
|
@ -53,12 +53,12 @@ class AlbertModelTester:
|
||||
use_labels=True,
|
||||
vocab_size=32,
|
||||
embedding_size=8,
|
||||
hidden_size=12,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
# this needs to be the same as `num_hidden_layers`!
|
||||
num_hidden_groups=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=16,
|
||||
intermediate_size=20,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
@ -259,7 +259,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
fx_compatible = True
|
||||
fx_compatible = False # will not be maintained
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
@ -310,6 +310,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
config_and_inputs[0]._attn_implementation = "eager"
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
|
Loading…
Reference in New Issue
Block a user