Add mask2former fp16 support (#25093)

* Add mask2former fp16 support

* Clear consistency/quality issues

* Fix consistency/quality (2)

* Add integration test for mask2former (fp16 case)

* Fix code quality

* Add integration test for maskformer (fp16 case)

* Add integration test for oneformer (fp16 case)

* Remove slow decorator from fp16 tests

* Fix lint

* Remove usage of full inference and value checks for fp16

* Temporarily comment slow for {mask, mask2, one}former

* Add fp16 support to oneformer

* Revert "Temporarily comment slow for {mask, mask2, one}former"

This reverts commit e5371edabd.

* Remove dtype conversion noop
Author: Pedro Lira, 2023-08-07 16:07:29 -03:00 (committed by GitHub)
commit 080a97119c, parent 5ee9693a1c
6 changed files with 97 additions and 37 deletions
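
With these changes, Mask2Former, MaskFormer and OneFormer can run inference end to end in half precision. A minimal usage sketch (the checkpoint name and CUDA device are illustrative assumptions, not part of this commit):

```python
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

# Example checkpoint; any Mask2Former checkpoint should behave the same.
checkpoint = "facebook/mask2former-swin-tiny-coco-instance"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = (
    Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint, torch_dtype=torch.float16)
    .to("cuda")
    .eval()
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(image, return_tensors="pt").to("cuda", dtype=torch.float16)

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.class_queries_logits.dtype)  # torch.float16
```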

File: src/transformers/models/mask2former/modeling_mask2former.py

@@ -864,15 +864,15 @@ class Mask2FormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

         pos_x = x_embed[:, :, :, None] / dim_t
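
The sine position embedding now derives every intermediate tensor from the dtype of the incoming feature map instead of hard-coding float32, which previously forced the pixel decoder back to fp32. A toy illustration of the dtype flow (shapes are arbitrary):

```python
import torch

x = torch.randn(1, 256, 8, 8)                  # stand-in feature map
mask = torch.zeros((1, 8, 8), dtype=torch.bool)
not_mask = (~mask).to(x.dtype)                 # bool -> x.dtype
y_embed = not_mask.cumsum(1)                   # stays in x.dtype
dim_t = torch.arange(128, dtype=x.dtype)
print(y_embed.dtype, dim_t.dtype)              # both follow x.dtype
# With x in torch.float16 (e.g. a model cast with .half() on GPU), the
# positional features come out in float16 as well, with no silent upcast.
```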
@@ -1104,8 +1104,8 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
         reference_points_list = []
         for lvl, (height, width) in enumerate(spatial_shapes):
             ref_y, ref_x = torch.meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
                 indexing="ij",
             )
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
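
Likewise, the reference points for the deformable-attention encoder now inherit their dtype from `valid_ratios`, which itself follows the flattened embeddings. A small sketch of that propagation (toy sizes):

```python
import torch

valid_ratios = torch.rand(1, 1, 2)             # (batch, num_levels, 2)
height, width = 4, 4
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype),
    torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype),
    indexing="ij",
)
print(ref_y.dtype)                             # matches valid_ratios.dtype
```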
@@ -1267,14 +1267,14 @@ class Mask2FormerPixelDecoder(nn.Module):
         self.lateral_convolutions = lateral_convs[::-1]
         self.output_convolutions = output_convs[::-1]

-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""
         _, height, width = mask.shape
         valid_height = torch.sum(~mask[:, :, 0], 1)
         valid_width = torch.sum(~mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
+        valid_ratio_heigth = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
         valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
         return valid_ratio
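
`get_valid_ratio` now takes the target dtype as an argument, so the ratios can be computed directly in half precision. A toy example with a hypothetical 4x6 padding mask (padded positions are True):

```python
import torch

mask = torch.zeros(1, 4, 6, dtype=torch.bool)
mask[:, 2:, :] = True                          # bottom two rows are padding
mask[:, :, 3:] = True                          # right three columns are padding

_, height, width = mask.shape
valid_height = torch.sum(~mask[:, :, 0], 1)    # 2 valid rows
valid_width = torch.sum(~mask[:, 0, :], 1)     # 3 valid columns
ratio = torch.stack(
    [valid_width.to(torch.float16) / width, valid_height.to(torch.float16) / height], -1
)
print(ratio)                                   # tensor([[0.5000, 0.5000]], dtype=torch.float16)
```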
@@ -1295,8 +1295,8 @@ class Mask2FormerPixelDecoder(nn.Module):
         input_embeds = []
         position_embeddings = []
         for level, x in enumerate(features[::-1][: self.num_feature_levels]):
-            input_embeds.append(self.input_projections[level](x.float()))
-            position_embeddings.append(self.position_embedding(x.float()))
+            input_embeds.append(self.input_projections[level](x))
+            position_embeddings.append(self.position_embedding(x))

         masks = [
             torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds
@@ -1313,7 +1313,7 @@ class Mask2FormerPixelDecoder(nn.Module):
         level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(mask) for mask in masks], 1)
+        valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)

         # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
         if encoder_outputs is None:
@@ -1351,7 +1351,7 @@ class Mask2FormerPixelDecoder(nn.Module):
         for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]):
             lateral_conv = self.lateral_convolutions[idx]
             output_conv = self.output_convolutions[idx]
-            current_fpn = lateral_conv(feature.float())
+            current_fpn = lateral_conv(feature)

             # Following FPN implementation, we use nearest upsampling here
             out = current_fpn + nn.functional.interpolate(

File: src/transformers/models/maskformer/modeling_maskformer.py

@@ -1290,15 +1290,15 @@ class MaskFormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

         pos_x = x_embed[:, :, :, None] / dim_t

File: src/transformers/models/oneformer/modeling_oneformer.py

@@ -1179,8 +1179,8 @@ class OneFormerPixelDecoderEncoderOnly(nn.Module):
         reference_points_list = []
         for lvl, (height, width) in enumerate(spatial_shapes):
             ref_y, ref_x = torch.meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
             )
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width)
@@ -1352,14 +1352,14 @@ class OneFormerPixelDecoder(nn.Module):
         self.lateral_convs = lateral_convs[::-1]
         self.output_convs = output_convs[::-1]

-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""
         _, height, width = mask.shape
         valid_height = torch.sum(~mask[:, :, 0], 1)
         valid_width = torch.sum(~mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
+        valid_ratio_heigth = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
         valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
         return valid_ratio
@@ -1380,9 +1380,8 @@ class OneFormerPixelDecoder(nn.Module):
         sources = []
         position_embeddings_list = []
         for level, source in enumerate(features[::-1][: self.num_feature_levels]):
-            feats = source.float()
-            sources.append(self.input_projections[level](feats))
-            position_embeddings_list.append(self.position_embedding(feats))
+            sources.append(self.input_projections[level](source))
+            position_embeddings_list.append(self.position_embedding(source))

         masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in sources]
@@ -1407,8 +1406,7 @@ class OneFormerPixelDecoder(nn.Module):
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
         spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
-        valid_ratios = valid_ratios.float()
+        valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)

         # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
         # Also provide spatial_shapes, level_start_index and valid_ratios
@@ -1445,7 +1443,6 @@ class OneFormerPixelDecoder(nn.Module):
         # append `out` with extra FPN levels
         # Reverse feature maps into top-down order (from low to high resolution)
         for idx, feats in enumerate(features[: self.num_fpn_levels][::-1]):
-            feats = feats.float()
             lateral_conv = self.lateral_convs[idx]
             output_conv = self.output_convs[idx]
             cur_fpn = lateral_conv(feats)
@@ -2396,15 +2393,15 @@ class OneFormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

         pos_x = x_embed[:, :, :, None] / dim_t
@@ -2744,7 +2741,7 @@ class OneFormerTaskModel(nn.Module):
         )

     def forward(self, inputs: Tensor) -> Tensor:
-        task_tokens = self.task_mlp(inputs.float())
+        task_tokens = self.task_mlp(inputs)
         return task_tokens
@@ -2980,7 +2977,7 @@ class OneFormerModel(OneFormerPreTrainedModel):
         multi_scale_features = pixel_level_module_output.decoder_features
         mask_features = pixel_level_module_output.decoder_last_feature

-        task_token = self.task_encoder(task_inputs)
+        task_token = self.task_encoder(task_inputs.to(self.dtype))

         if self.is_training:
             text_queries = self.text_mapper(text_inputs)
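
In OneFormer the task text is tokenized to integer ids and projected by a small MLP; the cast now targets the model's dtype at the call site (`task_inputs.to(self.dtype)`) instead of hard-coding `.float()` inside the task model. A minimal stand-in (layer sizes and vocabulary size are made up for illustration):

```python
import torch
from torch import nn

task_mlp = nn.Sequential(nn.Linear(77, 256), nn.ReLU(), nn.Linear(256, 256))
task_ids = torch.randint(0, 49408, (1, 77))        # tokenized task text (torch.long)

model_dtype = next(task_mlp.parameters()).dtype    # float32 here; float16 after task_mlp.half()
task_tokens = task_mlp(task_ids.to(model_dtype))   # cast the ids to the module's dtype first
print(task_tokens.dtype)                           # matches model_dtype
```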

File: tests/models/mask2former/test_modeling_mask2former.py

@@ -21,7 +21,14 @@ import numpy as np

 from tests.test_modeling_common import floats_tensor
 from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -420,6 +427,20 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
         image_processor = self.default_image_processor

File: tests/models/maskformer/test_modeling_maskformer.py

@@ -22,7 +22,14 @@ import numpy as np

 from tests.test_modeling_common import floats_tensor
 from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -509,6 +516,20 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         model = (
             MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")

File: tests/models/oneformer/test_modeling_oneformer.py

@@ -22,7 +22,14 @@ import numpy as np

 from tests.test_modeling_common import floats_tensor
 from transformers import OneFormerConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -533,6 +540,20 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        processor = self.default_processor
+        image = prepare_img()
+        inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         dummy_model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
         processor = self.default_processor