mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
add bbox input validation (#26294)
This commit is contained in:
parent
245532065d
commit
00247ea0de
@ -877,6 +877,9 @@ class BrosModel(BrosPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if bbox is None:
|
||||||
|
raise ValueError("You have to specify bbox")
|
||||||
|
|
||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
@ -924,13 +927,11 @@ class BrosModel(BrosPreTrainedModel):
|
|||||||
past_key_values_length=past_key_values_length,
|
past_key_values_length=past_key_values_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
bbox_position_embeddings = None
|
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
|
||||||
if bbox is not None:
|
if bbox.shape[-1] == 4:
|
||||||
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
|
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
|
||||||
if bbox.shape[-1] == 4:
|
scaled_bbox = bbox * self.config.bbox_scale
|
||||||
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
|
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
|
||||||
scaled_bbox = bbox * self.config.bbox_scale
|
|
||||||
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
|
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
|
Loading…
Reference in New Issue
Block a user