mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Allow only textual inputs to VisualBert (#13687)
This commit is contained in:
parent
93624bfee9
commit
50c746eeb7
@ -778,29 +778,30 @@ class VisualBertModel(VisualBertPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if visual_embeds is None:
|
||||
raise ValueError(
|
||||
f"`visual_embeds` can not be of type {type(visual_embeds)} when using a VisualBert Model."
|
||||
)
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
visual_input_shape = visual_embeds.size()[:-1]
|
||||
if visual_embeds is not None:
|
||||
visual_input_shape = visual_embeds.size()[:-1]
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
|
||||
if visual_attention_mask is None:
|
||||
if visual_embeds is not None and visual_attention_mask is None:
|
||||
visual_attention_mask = torch.ones(visual_input_shape, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if visual_embeds is not None:
|
||||
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
|
||||
)
|
||||
|
||||
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
|
||||
)
|
||||
else:
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
attention_mask, [batch_size, input_shape], device
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
Loading…
Reference in New Issue
Block a user