Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Added error when sequence length is bigger than max_position_embeddings (#32156)
* Added error when sequence length is bigger than max_position_embeddings
* Fixed formatting
* Fixed bug
* Changed copies to match
* Fixed bug
* Applied suggestions
* Removed redundant code
* Fixed bugs
* Bug fix
* Bug fix
* Added requested Changes
* Fixed bug
* Fixed unwanted change
* Fixed unwanted changes
* Fixed formatting
commit 0aaf124fb9
parent 1211e616a4
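The guard added by this commit is duplicated across six CLIP-style text-embedding modules (BLIP, CLIP, CLIPSeg, GroupViT, SigLIP and X-CLIP). A minimal sketch of how it surfaces to a caller, assuming a randomly initialized CLIPTextModel built from the default CLIPTextConfig (whose max_position_embeddings is 77); everything except the error message itself is illustrative:

import torch
from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig()      # default max_position_embeddings == 77
model = CLIPTextModel(config)  # random weights, no checkpoint download needed

# 100 tokens exceed the 77 learned positions, so the embeddings layer now
# raises the ValueError introduced in the diff below.
too_long = torch.randint(0, config.vocab_size, (1, 100))
try:
    model(input_ids=too_long)
except ValueError as err:
    print(err)

Before this change the same call would typically fail further down with a shape-mismatch error when the sliced position embeddings were added to the token embeddings.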
@@ -309,6 +309,13 @@ class BlipTextEmbeddings(nn.Module):
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        max_position_embedding = self.position_embedding.weight.shape[0]
+
+        if seq_length > max_position_embedding:
+            raise ValueError(
+                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+            )
 
         if position_ids is None:
             position_ids = self.position_ids[:, :seq_length]
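In each of these modules, position_embedding is an nn.Embedding whose number of rows equals config.max_position_embeddings, so the new check reads the limit directly off the weight table rather than the config object. A toy illustration with hypothetical sizes (not taken from any of the models above):

from torch import nn

# Hypothetical sizes, for illustration only.
position_embedding = nn.Embedding(77, 512)  # rows == max_position_embeddings
print(position_embedding.weight.shape[0])   # 77 -- the value the guard compares against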
@@ -277,6 +277,13 @@ class CLIPTextEmbeddings(nn.Module):
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        max_position_embedding = self.position_embedding.weight.shape[0]
+
+        if seq_length > max_position_embedding:
+            raise ValueError(
+                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+            )
 
         if position_ids is None:
             position_ids = self.position_ids[:, :seq_length]
@@ -244,6 +244,13 @@ class CLIPSegTextEmbeddings(nn.Module):
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        max_position_embedding = self.position_embedding.weight.shape[0]
+
+        if seq_length > max_position_embedding:
+            raise ValueError(
+                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+            )
 
         if position_ids is None:
             position_ids = self.position_ids[:, :seq_length]
@@ -446,6 +446,13 @@ class GroupViTTextEmbeddings(nn.Module):
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        max_position_embedding = self.position_embedding.weight.shape[0]
+
+        if seq_length > max_position_embedding:
+            raise ValueError(
+                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+            )
 
         if position_ids is None:
             position_ids = self.position_ids[:, :seq_length]
@@ -340,6 +340,13 @@ class SiglipTextEmbeddings(nn.Module):
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        max_position_embedding = self.position_embedding.weight.shape[0]
+
+        if seq_length > max_position_embedding:
+            raise ValueError(
+                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+            )
 
         if position_ids is None:
             position_ids = self.position_ids[:, :seq_length]
@@ -203,6 +203,13 @@ class XCLIPTextEmbeddings(nn.Module):
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        max_position_embedding = self.position_embedding.weight.shape[0]
+
+        if seq_length > max_position_embedding:
+            raise ValueError(
+                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+            )
 
         if position_ids is None:
             position_ids = self.position_ids[:, :seq_length]
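Callers that hit the new error can avoid it by truncating at tokenization time. A sketch, assuming the openai/clip-vit-base-patch32 checkpoint (any of the tokenizers paired with the models above works the same way):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Truncate to the model's position budget so the embeddings guard never fires.
inputs = tokenizer(
    ["a very long caption that would otherwise exceed the position table"],
    padding=True,
    truncation=True,
    max_length=tokenizer.model_max_length,  # 77 for CLIP checkpoints
    return_tensors="pt",
)
print(inputs["input_ids"].shape[-1])  # always <= 77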