Fix OwlViT torchscript tests (#18347)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-07-29 10:36:04 +02:00 committed by GitHub
parent a4ee463d95
commit a64bcb564d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1153,7 +1153,6 @@ class OwlViTClassPredictionHead(nn.Module):
class OwlViTForObjectDetection(OwlViTPreTrainedModel):
config_class = OwlViTConfig
main_input_name = "pixel_values"
def __init__(self, config: OwlViTConfig):
super().__init__(config)
@ -1246,8 +1245,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
def image_text_embedder(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.Tensor,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = None,
) -> torch.FloatTensor:
@ -1284,8 +1283,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
@replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.Tensor,
pixel_values: torch.FloatTensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -1338,8 +1337,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
if output_hidden_states:
outputs = self.owlvit(
pixel_values=pixel_values,
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@ -1350,8 +1349,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
# Embed images and text queries
feature_map, query_embeds = self.image_text_embedder(
pixel_values=pixel_values,
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
@ -1374,7 +1373,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
pred_boxes = self.box_predictor(image_feats, feature_map)
if not return_dict:
return (
output = (
pred_logits,
pred_boxes,
query_embeds,
@ -1383,6 +1382,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
text_model_last_hidden_states,
vision_model_last_hidden_states,
)
output = tuple(x for x in output if x is not None)
return output
return OwlViTObjectDetectionOutput(
image_embeds=feature_map,