mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[SegGPT] Fix loss calculation (#30421)
* Fixed main train issues * Added loss test * Update src/transformers/models/seggpt/modeling_seggpt.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Added missing labels arg in SegGptModel forward * Fixed typo * Added slow test to test loss calculation --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
37fa1f654f
commit
d26c14139c
@ -753,11 +753,15 @@ class SegGptModel(SegGptPreTrainedModel):
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
feature_ensemble: Optional[bool] = None,
|
||||
embedding_type: Optional[str] = None,
|
||||
labels: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SegGptEncoderOutput]:
|
||||
r"""
|
||||
labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
|
||||
Ground truth mask for input images.
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
@ -799,10 +803,21 @@ class SegGptModel(SegGptPreTrainedModel):
|
||||
|
||||
# Prepare inputs
|
||||
pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
|
||||
prompt_pixel_values = torch.cat((prompt_masks, prompt_masks), dim=2)
|
||||
prompt_pixel_values = (
|
||||
torch.cat((prompt_masks, prompt_masks), dim=2)
|
||||
if labels is None
|
||||
else torch.cat((prompt_masks, labels), dim=2)
|
||||
)
|
||||
|
||||
if bool_masked_pos is None and labels is not None:
|
||||
logger.warning_once(
|
||||
"Labels were provided, but bool_masked_pos were not. It will be set to default value. If you're training the model, make sure to provide a bool_masked_pos."
|
||||
)
|
||||
|
||||
# We concat on height axis so SegGPT can handle as a single image, hence we need to mask the portion
|
||||
# of the prompt pixels that will be destinated to the prediction as they don't add any information.
|
||||
# of the mask prompt pixels that will be destinated to the prediction as they don't add any information.
|
||||
# This is only the case for inference. In training, the model concat of prompt mask and label is masked
|
||||
# and reconstructed together (In-Context Painting).
|
||||
if bool_masked_pos is None:
|
||||
num_patches = self.embeddings.patch_embeddings.num_patches
|
||||
bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool).to(pixel_values.device)
|
||||
@ -840,7 +855,9 @@ def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> tor
|
||||
batch_size = tensor.shape[0]
|
||||
patch_size = int((tensor.shape[-1] / 3) ** 0.5)
|
||||
if patch_height * patch_width != tensor.shape[1]:
|
||||
raise ValueError(f"Number of patches {tensor.shape[1]} does not match patch height and width.")
|
||||
raise ValueError(
|
||||
f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
|
||||
)
|
||||
|
||||
tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
|
||||
tensor = tensor.permute(0, 5, 1, 3, 2, 4)
|
||||
@ -857,8 +874,7 @@ class SegGptLoss(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
prompt_pixel_values: torch.FloatTensor,
|
||||
prompt_masks: torch.FloatTensor,
|
||||
pred_masks: torch.FloatTensor,
|
||||
labels: torch.FloatTensor,
|
||||
bool_masked_pos: torch.BoolTensor,
|
||||
@ -866,11 +882,8 @@ class SegGptLoss(nn.Module):
|
||||
"""Computes the L1 loss between the predicted masks and the ground truth masks.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
|
||||
Concatenated pixel values from prompt and input images.
|
||||
|
||||
prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
|
||||
Concatenated pixel values from mask prompt.
|
||||
prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values from mask prompt.
|
||||
|
||||
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
|
||||
Predicted masks.
|
||||
@ -884,12 +897,12 @@ class SegGptLoss(nn.Module):
|
||||
Returns:
|
||||
`torch.FloatTensor`: The mean L1 loss between the predicted masks and the ground truth masks.
|
||||
"""
|
||||
ground_truth = torch.cat((prompt_masks, labels), dim=2)
|
||||
|
||||
mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
|
||||
mask = unpatchify(mask, pixel_values.shape[1] // self.patch_size, pixel_values.shape[2] // self.patch_size)
|
||||
# Changing dummy mask in prompt_pixel_values to labels values
|
||||
prompt_pixel_values = prompt_pixel_values.clone()
|
||||
prompt_pixel_values[:, :, prompt_pixel_values.shape[2] // 2 :, :] = labels
|
||||
loss = F.smooth_l1_loss(pred_masks, prompt_pixel_values, reduction="none", beta=self.beta)
|
||||
mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)
|
||||
|
||||
loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
|
||||
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
||||
|
||||
return loss
|
||||
@ -976,6 +989,7 @@ class SegGptForImageSegmentation(SegGptPreTrainedModel):
|
||||
bool_masked_pos=bool_masked_pos,
|
||||
feature_ensemble=feature_ensemble,
|
||||
embedding_type=embedding_type,
|
||||
labels=labels,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@ -988,7 +1002,7 @@ class SegGptForImageSegmentation(SegGptPreTrainedModel):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fn = SegGptLoss(self.config)
|
||||
loss = loss_fn(pixel_values, prompt_pixel_values, pred_masks, labels, bool_masked_pos)
|
||||
loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos)
|
||||
|
||||
if not return_dict:
|
||||
output = (pred_masks,)
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
@ -39,6 +40,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import SegGptForImageSegmentation, SegGptModel
|
||||
from transformers.models.seggpt.modeling_seggpt import SegGptLoss
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@ -298,6 +300,22 @@ class SegGptModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
model_row_output[key] = model_row_output[key][1:]
|
||||
recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
|
||||
|
||||
def test_seggpt_loss(self):
|
||||
torch.manual_seed(100)
|
||||
config = self.model_tester.get_config()
|
||||
|
||||
prompt_masks = torch.rand(1, config.num_channels, config.image_size, config.image_size)
|
||||
label = torch.rand(1, config.num_channels, config.image_size, config.image_size)
|
||||
pred_masks = torch.rand(1, config.num_channels, config.image_size * 2, config.image_size)
|
||||
# seq_len x 2 because the loss concatenates prompt_masks and labels as pred_masks is concatenated
|
||||
bool_masked_pos = torch.rand(1, self.model_tester.seq_length * 2) > 0.5
|
||||
|
||||
loss = SegGptLoss(config)
|
||||
loss_value = loss(prompt_masks, pred_masks, label, bool_masked_pos)
|
||||
expected_loss_value = torch.tensor(0.3340)
|
||||
|
||||
self.assertTrue(torch.allclose(loss_value, expected_loss_value, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "BAAI/seggpt-vit-large"
|
||||
@ -312,6 +330,20 @@ def prepare_img():
|
||||
return images, masks
|
||||
|
||||
|
||||
def prepare_bool_masked_pos(config: SegGptConfig):
|
||||
num_patches = math.prod([i // config.patch_size for i in config.image_size])
|
||||
mask_ratio = 0.75
|
||||
torch.manual_seed(2)
|
||||
num_masked_patches = int(num_patches * mask_ratio)
|
||||
shuffle_idx = torch.randperm(num_patches)
|
||||
bool_masked_pos = torch.FloatTensor([0] * (num_patches - num_masked_patches) + [1] * num_masked_patches)[
|
||||
shuffle_idx
|
||||
]
|
||||
bool_masked_pos = bool_masked_pos.unsqueeze(0).bool()
|
||||
|
||||
return bool_masked_pos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class SegGptModelIntegrationTest(unittest.TestCase):
|
||||
@ -390,3 +422,30 @@ class SegGptModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(outputs.pred_masks.shape, expected_shape)
|
||||
self.assertTrue(torch.allclose(outputs.pred_masks[0, :, 448:451, :3], expected_slice, atol=4e-4))
|
||||
|
||||
@slow
|
||||
def test_one_shot_with_label(self):
|
||||
model = SegGptForImageSegmentation.from_pretrained("BAAI/seggpt-vit-large").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
|
||||
images, masks = prepare_img()
|
||||
|
||||
input_image = images[1]
|
||||
label = masks[1]
|
||||
prompt_image = images[0]
|
||||
prompt_mask = masks[0]
|
||||
|
||||
inputs = image_processor(
|
||||
images=input_image, prompt_masks=prompt_mask, prompt_images=prompt_image, return_tensors="pt"
|
||||
).to(torch_device)
|
||||
|
||||
labels = image_processor(images=None, prompt_masks=label, return_tensors="pt")["prompt_masks"].to(torch_device)
|
||||
|
||||
bool_masked_pos = prepare_bool_masked_pos(model.config).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, labels=labels, bool_masked_pos=bool_masked_pos)
|
||||
|
||||
expected_loss = torch.tensor(0.0074).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-4))
|
||||
|
Loading…
Reference in New Issue
Block a user