Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
more more beautiful
This commit is contained in: parent 5bed246fac, commit ee8e823914
@@ -32,8 +32,6 @@ class Florence2VisionConfig(PretrainedConfig):
    Args:
        in_channels (`int`, *optional*, defaults to 3):
            Number of input image channels.
        num_classes (`int`, *optional*, defaults to 1000):
            Number of classes for classification head
        depths (`Tuple[int]`, *optional*, defaults to `(1, 1, 9, 1)`):
            The depth of the model.
        patch_size (`Tuple[int]`, *optional*, defaults to `(7, 3, 3, 3)`):
@@ -44,11 +42,11 @@ class Florence2VisionConfig(PretrainedConfig):
            The patch padding of the image.
        patch_prenorm (`Tuple[bool]`, *optional*, defaults to `(False, True, True, True)`):
            Whether to apply layer normalization before the patch embedding layer.
        embed_dim (`Tuple[int]`, *optional*, defaults to `(256, 512, 1024, 2048)`):
        embed_dim (`Tuple[int]`, *optional*, defaults to `(128, 256, 512, 1024)`):
            The dimension of the embedding layer.
        num_heads (`Tuple[int]`, *optional*, defaults to `(8, 16, 32, 64)`):
        num_heads (`Tuple[int]`, *optional*, defaults to `(4, 8, 16, 32)`):
            The number of attention heads.
        num_groups (`Tuple[int]`, *optional*, defaults to `(8, 16, 32, 64)`):
        num_groups (`Tuple[int]`, *optional*, defaults to `(4, 8, 16, 32)`):
            The number of groups.
        window_size (`int`, *optional*, defaults to 12):
            The window size of the model.
@@ -63,10 +61,8 @@ class Florence2VisionConfig(PretrainedConfig):
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        projection_dim (`int`, *optional*, defaults to 1024):
            The dimension of the projection layer.
        visual_temporal_embedding (`dict`, *optional*, defaults to `{'type': 'COSINE', 'max_temporal_embeddings': 100}`):
            The configuration of the visual temporal embedding.
        image_pos_embed (`dict`, *optional*, defaults to `{'type': 'learned_abs_2d', 'max_pos_embeddings': 50}`):
            The configuration of the image position embedding.
        max_temporal_embeddings (`int`, *optional*, defaults to 100): The configuration of the visual temporal embedding.
        max_position_embeddings (`int`, *optional*, defaults to 50): The configuration of the image position embedding.
        image_feature_source (`Tuple[str]`, *optional*, defaults to `('spatial_avg_pool', 'temporal_avg_pool')`):
            The source of the image feature.
        initializer_range (`float`, *optional*, defaults to 0.02):
@@ -91,23 +87,22 @@ class Florence2VisionConfig(PretrainedConfig):
    def __init__(
        self,
        in_channels=3,
        num_classes=1000,
        depths=(1, 1, 9, 1),
        patch_size=(7, 3, 3, 3),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 1, 1, 1),
        patch_prenorm=(False, True, True, True),
        embed_dim=(256, 512, 1024, 2048),
        num_heads=(8, 16, 32, 64),
        num_groups=(8, 16, 32, 64),
        embed_dim=(128, 256, 512, 1024),
        num_heads=(4, 8, 16, 32),
        num_groups=(4, 8, 16, 32),
        window_size=12,
        drop_path_rate=0.1,
        mlp_ratio=4.0,
        qkv_bias=True,
        activation_function="gelu",
        projection_dim=1024,
        visual_temporal_embedding={"type": "COSINE", "max_temporal_embeddings": 100},
        image_pos_embed={"type": "learned_abs_2d", "max_pos_embeddings": 50},
        max_temporal_embeddings=100,
        max_position_embeddings=50,
        image_feature_source=("spatial_avg_pool", "temporal_avg_pool"),
        initializer_range=0.02,
        **kwargs,
@@ -115,7 +110,6 @@ class Florence2VisionConfig(PretrainedConfig):
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.depths = depths
        self.patch_size = patch_size
        self.patch_stride = patch_stride
@@ -129,8 +123,8 @@ class Florence2VisionConfig(PretrainedConfig):
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.projection_dim = projection_dim
        self.visual_temporal_embedding = visual_temporal_embedding
        self.image_pos_embed = image_pos_embed
        self.max_temporal_embeddings = max_temporal_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.image_feature_source = image_feature_source
        self.initializer_range = initializer_range
        self.activation_function = activation_function
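For quick reference, a minimal sketch of instantiating the vision config shown above with either tuple sizing (the import path is an assumption based on the `__all__` export list further below):

# Hedged sketch: both sizings are passed explicitly so nothing depends on which
# tuple is the default after this commit.
from transformers import Florence2VisionConfig  # assumed export, per __all__ further below

small = Florence2VisionConfig(embed_dim=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), num_groups=(4, 8, 16, 32))
large = Florence2VisionConfig(embed_dim=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), num_groups=(8, 16, 32, 64))
print(small.embed_dim[-1], large.embed_dim[-1])  # 1024 2048 -- the last stage width is what feeds the image projection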
@@ -297,13 +291,6 @@ class Florence2Config(PretrainedConfig):
            Dictionary of configuration options used to initialize [`Florence2LanguageConfig`].
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Florence2VisionConfig`].
        vocab_size (`int`, *optional*, defaults to 51289):
            Vocabulary size of the Florence2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`~Florence2ForConditionalGeneration`]
        projection_dim (`int`, *optional*, defaults to 1024):
            Dimension of the multimodal projection space.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Whether the model is used as an encoder/decoder or not.

    Example:

@@ -336,31 +323,28 @@ class Florence2Config(PretrainedConfig):
        self,
        text_config=None,
        vision_config=None,
        vocab_size=51289,
        projection_dim=1024,
        is_encoder_decoder=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        super().__init__(**kwargs)

        if text_config is None:
            text_config = {}
        if isinstance(text_config, dict):
            text_config = Florence2LanguageConfig(**text_config)
        elif text_config is None:
            text_config = Florence2LanguageConfig()
            logger.info("text_config is None. Initializing the Florence2LanguageConfig with default values.")

        if vision_config is None:
            vision_config = {}
        if isinstance(vision_config, dict):
            vision_config = Florence2VisionConfig(**vision_config)
        elif vision_config is None:
            logger.info("vision_config is None. Initializing the Florence2VisionConfig with default values.")
            vision_config = Florence2VisionConfig()

        self.vision_config = Florence2VisionConfig(**vision_config)
        self.text_config = Florence2LanguageConfig(**text_config)
        self.text_config = text_config
        self.vision_config = vision_config

        super().__init__(
            vocab_size=vocab_size,
            projection_dim=projection_dim,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
        self.vocab_size = self.text_config.vocab_size
        self.is_encoder_decoder = self.text_config.is_encoder_decoder
        self.projection_dim = self.vision_config.projection_dim


__all__ = ["Florence2Config", "Florence2LanguageConfig", "Florence2VisionConfig"]
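A short sketch of what the refactored constructor above accepts: either plain dicts or config objects for the sub-configs, with `vocab_size`, `is_encoder_decoder`, and `projection_dim` now read back from them (imports assumed from the `__all__` just above; `vocab_size` being a field of `Florence2LanguageConfig` is inferred from `resize_token_embeddings` further below):

# Hedged sketch: both call styles should yield equivalent composite configs.
from transformers import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig  # assumed exports

cfg_from_dicts = Florence2Config(
    text_config={"vocab_size": 51289},       # converted via Florence2LanguageConfig(**text_config)
    vision_config={"projection_dim": 1024},  # converted via Florence2VisionConfig(**vision_config)
)
cfg_from_objects = Florence2Config(
    text_config=Florence2LanguageConfig(vocab_size=51289),
    vision_config=Florence2VisionConfig(projection_dim=1024),
)
assert cfg_from_dicts.vocab_size == cfg_from_objects.vocab_size == 51289
assert cfg_from_dicts.projection_dim == cfg_from_objects.projection_dim == 1024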
@@ -562,7 +562,6 @@ class Florence2VisionBackbone(Florence2VisionPreTrainedModel):
        super().__init__(config)
        self.config = config

        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        self.num_groups = config.num_groups
@@ -2184,37 +2183,24 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
        self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
        self.vocab_size = config.vocab_size
        self._attn_implementation = config._attn_implementation
        self._build_image_projection_layers(config)
        self.vision_embedding_dim = config.vision_config.embed_dim[-1]
        self.vision_projection_dim = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(torch.ones(self.vision_embedding_dim, self.vision_projection_dim))
        self.image_proj_norm = nn.LayerNorm(self.vision_projection_dim)

        self.image_position_embed = Florence2VisionLearnedAbsolutePositionEmbedding2D(
            embedding_dim=self.vision_embedding_dim, num_pos=config.vision_config.max_position_embeddings
        )

        self.visual_temporal_embed = Florence2VisionPositionalEmbeddingCosine1D(
            embed_dim=self.vision_embedding_dim, max_seq_len=config.vision_config.max_temporal_embeddings
        )

        self.language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def _build_image_projection_layers(self, config: Florence2Config):
        image_dim_out = config.vision_config.embed_dim[-1]
        dim_projection = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(torch.ones(image_dim_out, dim_projection))
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config["type"] == "learned_abs_2d":
            self.image_pos_embed = Florence2VisionLearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"]
            )
        else:
            raise NotImplementedError("Not implemented yet")

        self.image_feature_source = config.vision_config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
        if visual_temporal_embedding_config["type"] == "COSINE":
            self.visual_temporal_embed = Florence2VisionPositionalEmbeddingCosine1D(
                embed_dim=image_dim_out, max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"]
            )
        else:
            raise NotImplementedError("Not implemented yet")

    def get_encoder(self) -> Florence2LanguageEncoder:
        return self.language_model.get_encoder()
@@ -2235,58 +2221,25 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _encode_image(self, pixel_values: torch.Tensor):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f"invalid image shape {pixel_values.shape}")

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens**0.5), int(num_tokens**0.5)
            if h * w != num_tokens:
                raise ValueError("only support square feature maps for now")
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict["last_frame"] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError("invalid image feature source: {}".format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x
    def get_image_features(self, pixel_values: torch.Tensor):
        vision_features = self.vision_tower.forward_features_unpool(pixel_values)
        position_features = vision_features + self.image_position_embed(vision_features)
        position_features = position_features.flatten(2).transpose(1, 2)
        temporal_features = self.visual_temporal_embed(position_features[:, :, 0])
        temporal_features = temporal_features.unsqueeze(1)
        visual_token_features = position_features + temporal_features
        visual_token_features = visual_token_features.unsqueeze(1)
        spatial_image_features = visual_token_features.mean(dim=2)
        temporal_image_features = visual_token_features.mean(dim=1)
        image_features = torch.cat([spatial_image_features, temporal_image_features], dim=1)
        image_features = image_features @ self.image_projection
        image_features = self.image_proj_norm(image_features)
        return image_features

    def _merge_input_ids_with_image_features(
        self,
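To make the new `get_image_features` easier to follow, here is a standalone shape sketch that mirrors only its reshape, pooling, concatenation, and projection steps with dummy tensors (the batch size, channel width, and 24x24 token grid are illustrative assumptions; the positional and temporal embedding modules are left out):

import torch

batch, channels, tokens, proj_dim = 2, 1024, 576, 1024        # assumed sizes, e.g. a 24x24 feature map
position_features = torch.randn(batch, tokens, channels)       # stands in for the flatten(2).transpose(1, 2) output

visual_token_features = position_features.unsqueeze(1)         # (batch, T=1, tokens, channels)
spatial_image_features = visual_token_features.mean(dim=2)     # (batch, 1, channels) -- "spatial_avg_pool"
temporal_image_features = visual_token_features.mean(dim=1)    # (batch, tokens, channels) -- "temporal_avg_pool"
image_features = torch.cat([spatial_image_features, temporal_image_features], dim=1)  # (batch, 1 + tokens, channels)

image_projection = torch.ones(channels, proj_dim)              # stand-in for the learned image_projection parameter
image_features = image_features @ image_projection             # (batch, 1 + tokens, proj_dim)
print(image_features.shape)                                    # torch.Size([2, 577, 1024])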
@@ -2334,6 +2287,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
        r"""
        Args:
@@ -2379,7 +2333,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
        # 2. Merge text and images
        if pixel_values is not None:
            # (batch_size, num_image_tokens, hidden_size)
            image_features = self._encode_image(pixel_values)
            image_features = self.get_image_features(pixel_values)
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, attention_mask
            )
@@ -2404,6 +2358,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
@@ -2450,7 +2405,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
            inputs_embeds = self.get_input_embeddings()(input_ids)
        # 2. Merge text and images
        if pixel_values is not None:
            image_features = self._encode_image(pixel_values)
            image_features = self.get_image_features(pixel_values)
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, attention_mask
            )
@@ -504,7 +504,6 @@ class Florence2VisionBackbone(Florence2VisionPreTrainedModel):
        super().__init__(config)
        self.config = config

        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        self.num_groups = config.num_groups
@@ -809,37 +808,24 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
        self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
        self.vocab_size = config.vocab_size
        self._attn_implementation = config._attn_implementation
        self._build_image_projection_layers(config)
        self.vision_embedding_dim = config.vision_config.embed_dim[-1]
        self.vision_projection_dim = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(torch.ones(self.vision_embedding_dim, self.vision_projection_dim))
        self.image_proj_norm = nn.LayerNorm(self.vision_projection_dim)

        self.image_position_embed = Florence2VisionLearnedAbsolutePositionEmbedding2D(
            embedding_dim=self.vision_embedding_dim, num_pos=config.vision_config.max_position_embeddings
        )

        self.visual_temporal_embed = Florence2VisionPositionalEmbeddingCosine1D(
            embed_dim=self.vision_embedding_dim, max_seq_len=config.vision_config.max_temporal_embeddings
        )

        self.language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def _build_image_projection_layers(self, config: Florence2Config):
        image_dim_out = config.vision_config.embed_dim[-1]
        dim_projection = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(torch.ones(image_dim_out, dim_projection))
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config["type"] == "learned_abs_2d":
            self.image_pos_embed = Florence2VisionLearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"]
            )
        else:
            raise NotImplementedError("Not implemented yet")

        self.image_feature_source = config.vision_config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
        if visual_temporal_embedding_config["type"] == "COSINE":
            self.visual_temporal_embed = Florence2VisionPositionalEmbeddingCosine1D(
                embed_dim=image_dim_out, max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"]
            )
        else:
            raise NotImplementedError("Not implemented yet")

    def get_encoder(self) -> Florence2LanguageEncoder:
        return self.language_model.get_encoder()

@@ -860,58 +846,25 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _encode_image(self, pixel_values: torch.Tensor):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f"invalid image shape {pixel_values.shape}")

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens**0.5), int(num_tokens**0.5)
            if h * w != num_tokens:
                raise ValueError("only support square feature maps for now")
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict["last_frame"] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError("invalid image feature source: {}".format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x
    def get_image_features(self, pixel_values: torch.Tensor):
        vision_features = self.vision_tower.forward_features_unpool(pixel_values)
        position_features = vision_features + self.image_position_embed(vision_features)
        position_features = position_features.flatten(2).transpose(1, 2)
        temporal_features = self.visual_temporal_embed(position_features[:, :, 0])
        temporal_features = temporal_features.unsqueeze(1)
        visual_token_features = position_features + temporal_features
        visual_token_features = visual_token_features.unsqueeze(1)
        spatial_image_features = visual_token_features.mean(dim=2)
        temporal_image_features = visual_token_features.mean(dim=1)
        image_features = torch.cat([spatial_image_features, temporal_image_features], dim=1)
        image_features = image_features @ self.image_projection
        image_features = self.image_proj_norm(image_features)
        return image_features

    def _merge_input_ids_with_image_features(
        self,
@@ -959,6 +912,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
        r"""
        Args:
@@ -1004,7 +958,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
        # 2. Merge text and images
        if pixel_values is not None:
            # (batch_size, num_image_tokens, hidden_size)
            image_features = self._encode_image(pixel_values)
            image_features = self.get_image_features(pixel_values)
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, attention_mask
            )
@@ -1029,6 +983,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
@@ -1075,7 +1030,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
            inputs_embeds = self.get_input_embeddings()(input_ids)
        # 2. Merge text and images
        if pixel_values is not None:
            image_features = self._encode_image(pixel_values)
            image_features = self.get_image_features(pixel_values)
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, attention_mask
            )
@@ -18,29 +18,29 @@ Processor class for FLORENCE2.

import math
import re
from typing import List, Optional, Union
from typing import List, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ChannelDimension, ImageInput, is_valid_image
from ...processing_utils import ProcessorMixin
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import (
    MultiModalData,
    ProcessingKwargs,
    ProcessorMixin,
    Unpack,
    _validate_images_text_input_order,
)
from ...tokenization_utils_base import (
    PaddingStrategy,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
from ...utils import TensorType, is_torch_available, is_vision_available, logging
from ..bart.tokenization_bart import BartTokenizer
from ..bart.tokenization_bart_fast import BartTokenizerFast
from ...utils import is_torch_available, logging


if is_torch_available():
    import torch

if is_vision_available():
    from ...image_utils import PILImageResampling

logger = logging.get_logger(__name__)

@@ -59,6 +59,13 @@ def _is_str_or_image(elem):
    return isinstance(elem, (str)) or is_image_or_image_url(elem)


class Florence2ProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {"padding": False, "return_token_type_ids": False},
        "images_kwargs": {},
    }


class Florence2Processor(ProcessorMixin):
    r"""
    Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
@@ -69,12 +76,12 @@ class Florence2Processor(ProcessorMixin):
    Args:
        image_processor ([`CLIPImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`BartTokenizerFast`], *optional*):
        tokenizer ([`BartTokenizer`, `BartTokenizerFast`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")

    def __init__(
@@ -183,59 +190,29 @@ class Florence2Processor(ProcessorMixin):

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        tokenize_newline_separately: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        do_resize: Optional[bool] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[ChannelDimension] = "channels_first",
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        resample: Optional["PILImageResampling"] = None,
        do_convert_rgb: Optional[bool] = None,
        do_thumbnail: Optional[bool] = None,
        do_align_long_axis: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[Florence2ProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            tokenize_newline_separately (`bool`, defaults to `True`):
                Adds a separately tokenized '\n' at the end of the prompt.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
@@ -244,82 +221,89 @@ class Florence2Processor(ProcessorMixin):
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
              is provided, the `input_ids` will also contain the suffix input ids.
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **labels** -- Labels compatible with training if `suffix` is not None
        """

        return_token_type_ids = False
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        if images is None:
        # check if images and text inputs are reversed for BC
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            Florence2ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
        else:
            raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")

        if text is None:
            logger.warning_once("You are using Florence-2 without a text prompt.")
            logger.warning_once("You are using Florence-2 without a text prefix.")
            text = ""

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        if isinstance(text, List) and isinstance(images, List):
            if len(images) < len(text):
                raise ValueError(
                    f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
                )
        if _is_str_or_image(text):
            text = [text]
        elif isinstance(text, list) and _is_str_or_image(text[0]):
            pass

        pixel_values = self.image_processor(
            images,
            do_resize=do_resize,
            do_normalize=do_normalize,
            return_tensors=return_tensors,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
            data_format=data_format,
            resample=resample,
            do_convert_rgb=do_convert_rgb,
        )["pixel_values"]
        prompt_strings = self._construct_prompts(text)
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=return_tensors)

        if max_length is not None:
            max_length -= self.image_seq_length  # max_length has to account for the image tokens
        return_data = {**text_inputs, "pixel_values": pixel_values}

        text = self._construct_prompts(text)

        inputs = self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_token_type_ids=return_token_type_ids,
        )

        return_data = {**inputs, "pixel_values": pixel_values}

        if return_token_type_ids:
            labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
            return_data.update({"labels": labels})
        return BatchFeature(data=return_data)
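A hedged usage sketch of the refactored `__call__` above: images and text plus free-form kwargs that are routed through `Florence2ProcessorKwargs`. The checkpoint name is taken from the tests at the end of this diff; the `<OD>` prompt and the printed shapes are assumptions, and loading assumes the checkpoint's processor config resolves to this class:

from PIL import Image
from transformers import Florence2Processor  # exercised the same way in the tests below

processor = Florence2Processor.from_pretrained("microsoft/Florence-2-base")
image = Image.new("RGB", (768, 768))          # dummy image, just to exercise the call
inputs = processor(
    images=image,
    text="<OD>",                              # assumed task prompt
    padding="max_length",                     # merged into output_kwargs["text_kwargs"]
    max_length=64,
    return_tensors="pt",
)
print(sorted(inputs.keys()))                  # expected: ['attention_mask', 'input_ids', 'pixel_values']
print(inputs["pixel_values"].shape)           # e.g. torch.Size([1, 3, 768, 768]), depending on the image processor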
    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (List[List[str]], *optional*):
                The input sizes formatted as (height, width) per each image.
        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """

        vision_data = {}
        if image_sizes is not None:
            num_image_tokens = [self.image_seq_length] * len(image_sizes)
            num_image_patches = [1] * len(image_sizes)
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
        return MultiModalData(**vision_data)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
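Continuing the processor sketch above, `_get_num_multimodal_tokens` would report one patch and `image_seq_length` placeholder tokens per image; the concrete value of `image_seq_length` is not shown in this diff, so the numbers below are assumptions:

# Hedged continuation of the processor sketch; attribute access on MultiModalData
# is assumed to mirror the keyword names passed to it above.
mm_data = processor._get_num_multimodal_tokens(image_sizes=[(768, 768), (1024, 683)])
print(mm_data.num_image_tokens)   # [image_seq_length, image_seq_length], e.g. [577, 577] if image_seq_length were 577
print(mm_data.num_image_patches)  # [1, 1]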
@@ -515,6 +499,61 @@ class CoordinatesQuantizer(object):
        return dequantized_coordinates


class PointsQuantizer(object):
    """
    Quantize points (Nx2)
    """

    def __init__(self, mode, bins):
        self.mode = mode
        self.bins = bins

    def quantize(self, coordinates, size):
        bins_w, bins_h = self.bins  # Quantization bins.
        size_w, size_h = size  # Original image size.
        size_per_bin_w = size_w / bins_w
        size_per_bin_h = size_h / bins_h
        assert coordinates.shape[-1] == 2, "coordinates should be shape (N, 2)"
        x, y = coordinates.split(1, dim=-1)  # Shape: 2 * [N, 1].

        if self.mode == "floor":
            quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1)
            quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1)

        elif self.mode == "round":
            raise NotImplementedError()

        else:
            raise ValueError("Incorrect quantization type.")

        quantized_coordinates = torch.cat((quantized_x, quantized_y), dim=-1).int()

        return quantized_coordinates

    def dequantize(self, coordinates, size):
        bins_w, bins_h = self.bins  # Quantization bins.
        size_w, size_h = size  # Original image size.
        size_per_bin_w = size_w / bins_w
        size_per_bin_h = size_h / bins_h
        assert coordinates.shape[-1] == 2, "coordinates should be shape (N, 2)"
        x, y = coordinates.split(1, dim=-1)  # Shape: 2 * [N, 1].

        if self.mode == "floor":
            # Add 0.5 to use the center position of the bin as the coordinate.
            dequantized_x = (x + 0.5) * size_per_bin_w
            dequantized_y = (y + 0.5) * size_per_bin_h

        elif self.mode == "round":
            raise NotImplementedError()

        else:
            raise ValueError("Incorrect quantization type.")

        dequantized_coordinates = torch.cat((dequantized_x, dequantized_y), dim=-1)

        return dequantized_coordinates

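A quick round-trip sketch of the `PointsQuantizer` added above; the 1000-bin setting mirrors the sample config in the post-processor docstring below, and the image size is an illustrative assumption:

import torch

quantizer = PointsQuantizer(mode="floor", bins=(1000, 1000))
points = torch.tensor([[10.0, 20.0], [767.0, 383.0]])      # (N, 2) points in pixel space
bins = quantizer.quantize(points, size=(768, 384))         # tensor([[ 13,  52], [998, 997]], dtype=torch.int32)
restored = quantizer.dequantize(bins, size=(768, 384))     # bin centers, close to the original points
print(bins)
print(restored)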
class Florence2PostProcessor(object):
    r"""
    Florence-2 post process for converting text prediction to various tasks results.
@@ -522,22 +561,6 @@ class Florence2PostProcessor(object):
    Args:
        config: A dict of configs.
        tokenizer: A tokenizer for decoding text to spans.
    sample config:
        UNIFIED_POST_PROCESS:
            # common configs
            NUM_BBOX_HEIGHT_BINS: 1000
            NUM_BBOX_WIDTH_BINS: 1000
            COORDINATES_HEIGHT_BINS: 1000
            COORDINATES_WIDTH_BINS: 1000
            # task specific configs, override the common configs
            PRASE_TASKS:
                - TASK_NAME: 'video_dense_caption'
                  PATTERN: r'<time_(\d+)><time_(\d+)>([a-zA-Z0-9 ]+)'
                  SCORE_MODE: 'avg_cat_name_scores'
                  NUM_BINS: 100
                - TASK_NAME: 'od'
                  PATTERN: 'r<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>([a-zA-Z0-9 ]+)'
                  SCORE_MODE: 'avg_cat_name_scores'

    Returns:
        parsed_dict (dict): A dict of parsed results.
@@ -746,23 +769,17 @@ class Florence2PostProcessor(object):
    def decode_with_spans(self, tokenizer, token_ids):
        filtered_tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        assert len(filtered_tokens) == len(token_ids)
        sub_texts = []
        for token in filtered_tokens:
            if token in self.all_special_tokens:
                sub_texts.append(token)
            else:
                if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
                    sub_text = tokenizer.convert_tokens_to_string([token])
                else:
                    raise ValueError(f"type {type(tokenizer)} not supported")
                sub_texts.append(sub_text)

        text = ""
        spans = []
        for sub_text in sub_texts:
            span = (len(text), len(text) + len(sub_text))  # [start index, end index).
        for token in filtered_tokens:
            if token in self.all_special_tokens:
                sub_text = token
            else:
                sub_text = tokenizer.convert_tokens_to_string([token])
            span = (len(text), len(text) + len(sub_text))
            text += sub_text
            spans.append(span)

        return text, spans
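The span bookkeeping in `decode_with_spans` above is easy to picture in isolation; a minimal standalone sketch with made-up sub-texts (no tokenizer involved):

# Accumulate the decoded string while recording the [start, end) character span
# contributed by each piece, exactly like the rewritten loop above.
sub_texts = ["<OD>", "A", " cat", "<loc_52>"]   # made-up pieces, not real tokenizer output
text, spans = "", []
for sub_text in sub_texts:
    span = (len(text), len(text) + len(sub_text))
    text += sub_text
    spans.append(span)
print(text)    # <OD>A cat<loc_52>
print(spans)   # [(0, 4), (4, 5), (5, 9), (9, 17)]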
    def parse_od_from_text_and_spans(self, text, pattern, image_size, phrase_centric=False):
@@ -807,7 +824,7 @@ class Florence2PostProcessor(object):
            quad_box = ocr_line[1:]
            quad_box = [int(i) for i in quad_box]
            quad_box = (
                self.coordinates_quantizer.dequantize(torch.tensor(np.array(quad_box).reshape(-1, 2)), size=image_size)
                self.coordinates_quantizer.dequantize(torch.tensor(quad_box).reshape(-1, 2), size=image_size)
                .reshape(-1)
                .tolist()
            )
@@ -860,30 +877,25 @@ class Florence2PostProcessor(object):
                cur_span += len(pharse_text)
                continue

            # Prepare instance.
            instance = {}

            # parse phrase, get string
            phrase = re.search(pattern, phrase_text_strip)
            if phrase is None:
                cur_span += len(pharse_text)
                continue

            phrase = phrase.group().strip()
            if phrase in self.black_list_of_phrase_grounding:
                cur_span += len(pharse_text)
                continue

            # parse bboxes by box_pattern
            bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
            if len(bboxes_parsed) == 0:
                cur_span += len(pharse_text)
                continue

            phrase = phrase.group()
            # remove leading and trailing spaces
            phrase = phrase.strip()

            if phrase in self.black_list_of_phrase_grounding:
                cur_span += len(pharse_text)
                continue

            # a list of list
            # Prepare instance.
            instance = {}
            bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
            instance["bbox"] = self.box_quantizer.dequantize(boxes=torch.tensor(bbox_bins), size=image_size).tolist()

@@ -1044,12 +1056,7 @@ class Florence2PostProcessor(object):
            phrase = re.search(phrase_string_pattern, phrase_text_strip)
            if phrase is None:
                continue
            phrase = phrase.group()
            # remove leading and trailing spaces
            phrase = phrase.strip()

            # parse bboxes by box_pattern

            phrase = phrase.group().strip()
            # split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
            if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
                polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
@@ -59,7 +59,6 @@ class Florence2VisionText2TextModelTester:
        seq_length=13,
        encoder_seq_length=18,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        max_position_embeddings=64,
        encoder_layers=1,
@@ -68,7 +67,7 @@ class Florence2VisionText2TextModelTester:
        decoder_ffn_dim=8,
        num_attention_heads=1,
        d_model=8,
        hidden_act="gelu",
        activation_function="gelu",
        dropout=0.1,
        eos_token_id=2,
        bos_token_id=0,
@@ -78,7 +77,7 @@ class Florence2VisionText2TextModelTester:
        patch_stride=[4],
        patch_padding=[3],
        patch_prenorm=[False],
        dim_embed=[8],
        embed_dim=[8],
        num_heads=[1],
        num_groups=[1],
        window_size=12,
@@ -104,7 +103,7 @@ class Florence2VisionText2TextModelTester:
        self.decoder_ffn_dim = decoder_ffn_dim
        self.num_attention_heads = num_attention_heads
        self.d_model = d_model
        self.activation_function = hidden_act
        self.activation_function = activation_function
        self.dropout = dropout
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
@@ -117,7 +116,7 @@ class Florence2VisionText2TextModelTester:
        self.patch_stride = patch_stride
        self.patch_padding = patch_padding
        self.patch_prenorm = patch_prenorm
        self.dim_embed = dim_embed
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.window_size = window_size
@@ -150,14 +149,15 @@ class Florence2VisionText2TextModelTester:
            patch_stride=self.patch_stride,
            patch_padding=self.patch_padding,
            patch_prenorm=self.patch_prenorm,
            dim_embed=self.dim_embed,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_groups=self.num_groups,
            window_size=self.window_size,
            activation_function=self.activation_function,
            projection_dim=self.projection_dim,
        )

        return Florence2Config.from_text_vision_configs(text_config=text_config, vision_config=vision_config)
        return Florence2Config(text_config=text_config, vision_config=vision_config)

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor(
@@ -15,7 +15,7 @@ import shutil
import tempfile
import unittest

from transformers import AutoProcessor, AutoTokenizer, BartTokenizerFast, Florence2Processor
from transformers import AutoProcessor, BartTokenizerFast, Florence2Processor
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available

@@ -55,9 +55,3 @@ class Florence2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
        )
    def test_processor_to_json_string(self):
        pass

    def test_can_load_various_tokenizers(self):
        for checkpoint in ["microsoft/Florence-2-base"]:
            processor = Florence2Processor.from_pretrained(checkpoint)
            tokenizer = AutoTokenizer.from_pretrained(checkpoint)
            self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)