add bbox input validation (#26294)

This commit is contained in:
Jinho Park 2023-09-20 23:48:35 +09:00 committed by GitHub
parent 245532065d
commit 00247ea0de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -877,6 +877,9 @@ class BrosModel(BrosPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if bbox is None:
raise ValueError("You have to specify bbox")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
@ -924,13 +927,11 @@ class BrosModel(BrosPreTrainedModel):
past_key_values_length=past_key_values_length,
)
bbox_position_embeddings = None
if bbox is not None:
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
if bbox.shape[-1] == 4:
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
scaled_bbox = bbox * self.config.bbox_scale
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
if bbox.shape[-1] == 4:
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
scaled_bbox = bbox * self.config.bbox_scale
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
encoder_outputs = self.encoder(
embedding_output,