Allow only textual inputs to VisualBert (#13687)

Gunjan Chhablani, 2021-09-22 21:21:53 +05:30, committed by GitHub
parent 93624bfee9
commit 50c746eeb7

@@ -778,29 +778,30 @@ class VisualBertModel(VisualBertPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if visual_embeds is None:
-            raise ValueError(
-                f"`visual_embeds` can not be of type {type(visual_embeds)} when using a VisualBert Model."
-            )
-
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device

-        visual_input_shape = visual_embeds.size()[:-1]
+        if visual_embeds is not None:
+            visual_input_shape = visual_embeds.size()[:-1]

         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)

-        if visual_attention_mask is None:
+        if visual_embeds is not None and visual_attention_mask is None:
             visual_attention_mask = torch.ones(visual_input_shape, device=device)

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
-        )
+        if visual_embeds is not None:
+            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
+            )
+        else:
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                attention_mask, [batch_size, input_shape], device
+            )

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
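With this change, `visual_embeds` becomes optional for the base `VisualBertModel`: when it is absent, the extended attention mask is built from the textual attention mask alone instead of the concatenated text-plus-visual mask. A minimal usage sketch of the text-only path (not part of the commit; the checkpoint name and the dummy visual features are illustrative):

import torch
from transformers import BertTokenizer, VisualBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

inputs = tokenizer("Who is eating the apple?", return_tensors="pt")

# Text-only forward pass: with visual_embeds=None the model now takes the new
# `else` branch above and builds the attention mask from the text inputs alone.
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, text_seq_len, hidden_size)

# The visual path is unchanged: text and visual attention masks are concatenated.
visual_embeds = torch.zeros(1, 10, model.config.visual_embedding_dim)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
outputs = model(
    **inputs,
    visual_embeds=visual_embeds,
    visual_attention_mask=visual_attention_mask,
)
print(outputs.last_hidden_state.shape)  # (1, text_seq_len + 10, hidden_size)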