mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
add bbox input validation (#26294)
This commit is contained in:
parent
245532065d
commit
00247ea0de
@ -877,6 +877,9 @@ class BrosModel(BrosPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if bbox is None:
|
||||
raise ValueError("You have to specify bbox")
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
@ -924,13 +927,11 @@ class BrosModel(BrosPreTrainedModel):
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
bbox_position_embeddings = None
|
||||
if bbox is not None:
|
||||
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
|
||||
if bbox.shape[-1] == 4:
|
||||
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
|
||||
scaled_bbox = bbox * self.config.bbox_scale
|
||||
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
|
||||
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
|
||||
if bbox.shape[-1] == 4:
|
||||
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
|
||||
scaled_bbox = bbox * self.config.bbox_scale
|
||||
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
|
Loading…
Reference in New Issue
Block a user