mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix init test
This commit is contained in:
parent
88e0d6e44c
commit
505a878da9
@ -2584,7 +2584,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
|
||||
def _build_image_projection_layers(self, config: Florence2Config):
|
||||
image_dim_out = config.vision_config.dim_embed[-1]
|
||||
dim_projection = config.vision_config.projection_dim
|
||||
self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
|
||||
self.image_projection = nn.Parameter(torch.zeros(image_dim_out, dim_projection))
|
||||
self.image_proj_norm = nn.LayerNorm(dim_projection)
|
||||
image_pos_embed_config = config.vision_config.image_pos_embed
|
||||
if image_pos_embed_config["type"] == "learned_abs_2d":
|
||||
|
@ -568,8 +568,8 @@ class Florence2VisionBlock(nn.Module):
|
||||
class Florence2VisionPreTrainedModel(PreTrainedModel):
|
||||
config_class = Florence2VisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
_supports_sdpa = False
|
||||
_supports_flash_attn_2 = False
|
||||
# _supports_sdpa = False
|
||||
# _supports_flash_attn_2 = False
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, nn.BatchNorm2d]) -> None:
|
||||
"""Initialize the weights"""
|
||||
@ -1020,7 +1020,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
|
||||
def _build_image_projection_layers(self, config: Florence2Config):
|
||||
image_dim_out = config.vision_config.dim_embed[-1]
|
||||
dim_projection = config.vision_config.projection_dim
|
||||
self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
|
||||
self.image_projection = nn.Parameter(torch.zeros(image_dim_out, dim_projection))
|
||||
self.image_proj_norm = nn.LayerNorm(dim_projection)
|
||||
image_pos_embed_config = config.vision_config.image_pos_embed
|
||||
if image_pos_embed_config["type"] == "learned_abs_2d":
|
||||
|
@ -382,7 +382,7 @@ class Florence2ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_batched_generation(self):
|
||||
model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-base")
|
||||
model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-base").to(torch_device)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user