copy fixes

Author: Vasqu
Date: 2025-07-11 15:59:42 +02:00
parent d9f0a8a304
commit 367fe5d043
2 changed files with 2 additions and 4 deletions
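
For context: a `# Copied from ...` comment in the transformers codebase marks code that must stay identical to the referenced implementation, after applying any `with X->Y` rename given in the comment, and `utils/check_copies.py` enforces the match. The two diffs below, against the Lilt and Tapas modeling files (paths inferred from the class and module names in the hunks), fix markers that had gone stale, presumably because the referenced Bert code changed. A minimal sketch of the convention, using a hypothetical MyModelAttention:

    import torch
    from torch import nn

    class MyModelAttention(nn.Module):
        # Copied from transformers.models.rembert.modeling_rembert.RemBertAttention.forward
        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # The checker compares everything under the marker, line by line,
            # against RemBertAttention.forward; a mismatch fails the repo checks.
            # MyModelAttention and this stub body are illustrative only.
            ...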

src/transformers/models/lilt/modeling_lilt.py

@@ -481,12 +481,10 @@ class LiltLayer(GradientCheckpointingLayer):
 
 
 class LiltEncoder(nn.Module):
-    # Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Lilt
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([LiltLayer(config) for _ in range(config.num_hidden_layers)])
-        self.gradient_checkpointing = False
 
     def forward(
         self,
src/transformers/models/tapas/modeling_tapas.py

@@ -414,7 +414,7 @@ class TapasAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
+    # Copied from transformers.models.rembert.modeling_rembert.RemBertAttention.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -485,7 +485,7 @@ class TapasLayer(GradientCheckpointingLayer):
         self.intermediate = TapasIntermediate(config)
         self.output = TapasOutput(config)
 
-    # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
+    # Copied from transformers.models.rembert.modeling_rembert.RemBertLayer.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
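
Fixes like these can be verified with the repo's copy checker: `python utils/check_copies.py` reports code that has drifted from its `# Copied from` source, and `python utils/check_copies.py --fix_and_overwrite` (also run as part of `make fix-copies`) rewrites stale copies in place.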