mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Fix OwlViT torchscript tests (#18347)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
a4ee463d95
commit
a64bcb564d
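Judging from the hunks below, the fix has two parts: it swaps the order of `input_ids` and `pixel_values` in `OwlViTForObjectDetection` (both signatures and call sites) and drops the class-level `main_input_name` override, and it rewrites the non-`return_dict` path to filter `None` entries out of the returned tuple. Both changes target torchscript tracing, which binds example inputs positionally and only accepts tensors (or nested containers of tensors) as traced outputs.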
@@ -1153,7 +1153,6 @@ class OwlViTClassPredictionHead(nn.Module):

 class OwlViTForObjectDetection(OwlViTPreTrainedModel):
     config_class = OwlViTConfig
-    main_input_name = "pixel_values"

     def __init__(self, config: OwlViTConfig):
         super().__init__(config)
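A side effect of dropping the override, assuming the inherited default from `PreTrainedModel` applies, is that the attribute now matches the new first positional argument. A quick check, using the public `google/owlvit-base-patch32` checkpoint:

```python
from transformers import OwlViTForObjectDetection

model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
# With the class-level override gone, the attribute resolves through normal
# inheritance; PreTrainedModel defines main_input_name = "input_ids".
print(model.main_input_name)  # expected: "input_ids"
```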
@@ -1246,8 +1245,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):

     def image_text_embedder(
         self,
-        pixel_values: torch.FloatTensor,
         input_ids: torch.Tensor,
+        pixel_values: torch.FloatTensor,
         attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = None,
     ) -> torch.FloatTensor:
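Why the order matters: `torch.jit.trace` binds the example-input tuple to `forward()` parameters by position. A minimal sketch with a hypothetical two-input module (not the actual test code):

```python
import torch
from torch import nn


class TwoInputModel(nn.Module):
    # Hypothetical stand-in for a text+image model such as OwlViT.
    def forward(self, input_ids: torch.Tensor, pixel_values: torch.Tensor):
        return input_ids.float().mean() + pixel_values.mean()


model = TwoInputModel().eval()
input_ids = torch.ones(1, 16, dtype=torch.long)
pixel_values = torch.rand(1, 3, 768, 768)

# trace() matches the example inputs to forward() positionally, so the tuple
# (input_ids, pixel_values) requires input_ids to be the first parameter.
traced = torch.jit.trace(model, (input_ids, pixel_values))
```

The remaining call-site hunks pass keyword arguments, so reordering them is mainly about keeping the calls consistent with the new signatures.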
@@ -1284,8 +1283,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
     @replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
     def forward(
         self,
-        pixel_values: torch.FloatTensor,
         input_ids: torch.Tensor,
+        pixel_values: torch.FloatTensor,
         attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1338,8 +1337,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):

         if output_hidden_states:
             outputs = self.owlvit(
-                pixel_values=pixel_values,
                 input_ids=input_ids,
+                pixel_values=pixel_values,
                 attention_mask=attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
@@ -1350,8 +1349,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):

         # Embed images and text queries
         feature_map, query_embeds = self.image_text_embedder(
-            pixel_values=pixel_values,
             input_ids=input_ids,
+            pixel_values=pixel_values,
             attention_mask=attention_mask,
             output_attentions=output_attentions,
         )
@@ -1374,7 +1373,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         pred_boxes = self.box_predictor(image_feats, feature_map)

         if not return_dict:
-            return (
+            output = (
                 pred_logits,
                 pred_boxes,
                 query_embeds,
@@ -1383,6 +1382,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
                 text_model_last_hidden_states,
                 vision_model_last_hidden_states,
             )
+            output = tuple(x for x in output if x is not None)
+            return output

         return OwlViTObjectDetectionOutput(
             image_embeds=feature_map,
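The tuple rewrite addresses the other torchscript constraint: a traced function may only return tensors (or nested lists, tuples, and dicts of tensors), so `None` placeholders such as the skipped hidden states have to be dropped before returning. A hedged sketch of the pattern in isolation (shapes are illustrative):

```python
import torch

# Optional outputs are None when not requested (e.g. hidden states).
text_hidden = None
pred_logits = torch.zeros(1, 576, 81)
pred_boxes = torch.zeros(1, 576, 4)

output = (pred_logits, pred_boxes, text_hidden)
# Dropping None entries keeps the tuple traceable by torch.jit.trace.
output = tuple(x for x in output if x is not None)
print(len(output))  # 2
```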