From 080a97119c0dabfd0fb5c3e26a872ad2958e4f77 Mon Sep 17 00:00:00 2001 From: Pedro Lira Date: Mon, 7 Aug 2023 16:07:29 -0300 Subject: [PATCH] Add mask2former fp16 support (#25093) * Add mask2former fp16 support * Clear consistency/quality issues * Fix consistency/quality (2) * Add integration test for mask2former (fp16 case) * Fix code quality * Add integration test for maskformer (fp16 case) * Add integration test for oneformer (fp16 case) * Remove slow decorator from fp16 tests * Fix lint * Remove usage of full inference and value checks for fp16 * Temporarily comment slow for {mask, mask2, one}former * Add fp16 support to oneformer * Revert "Temporarily comment slow for {mask, mask2, one}former" This reverts commit e5371edabd301cf56079def0421a0a87df307cb0. * Remove dtype conversion noop --- .../mask2former/modeling_mask2former.py | 26 ++++++++-------- .../models/maskformer/modeling_maskformer.py | 8 ++--- .../models/oneformer/modeling_oneformer.py | 31 +++++++++---------- .../mask2former/test_modeling_mask2former.py | 23 +++++++++++++- .../maskformer/test_modeling_maskformer.py | 23 +++++++++++++- .../oneformer/test_modeling_oneformer.py | 23 +++++++++++++- 6 files changed, 97 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index f814e1a0817..96e03b84188 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -864,15 +864,15 @@ class Mask2FormerSinePositionEmbedding(nn.Module): def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: if mask is None: mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) + not_mask = (~mask).to(x.dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -1104,8 +1104,8 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module): reference_points_list = [] for lvl, (height, width) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid( - torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device), - torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device), + torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), indexing="ij", ) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height) @@ -1267,14 +1267,14 @@ class Mask2FormerPixelDecoder(nn.Module): self.lateral_convolutions = lateral_convs[::-1] self.output_convolutions = output_convs[::-1] - def get_valid_ratio(self, mask): + def get_valid_ratio(self, mask, dtype=torch.float32): """Get the valid ratio of all feature maps.""" _, height, width = mask.shape valid_height = torch.sum(~mask[:, :, 0], 1) valid_width = torch.sum(~mask[:, 0, :], 1) - valid_ratio_heigth = valid_height.float() / height - valid_ratio_width = 
valid_width.float() / width + valid_ratio_heigth = valid_height.to(dtype) / height + valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) return valid_ratio @@ -1295,8 +1295,8 @@ class Mask2FormerPixelDecoder(nn.Module): input_embeds = [] position_embeddings = [] for level, x in enumerate(features[::-1][: self.num_feature_levels]): - input_embeds.append(self.input_projections[level](x.float())) - position_embeddings.append(self.position_embedding(x.float())) + input_embeds.append(self.input_projections[level](x)) + position_embeddings.append(self.position_embedding(x)) masks = [ torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds @@ -1313,7 +1313,7 @@ class Mask2FormerPixelDecoder(nn.Module): level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - valid_ratios = torch.stack([self.get_valid_ratio(mask) for mask in masks], 1) + valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1) # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder if encoder_outputs is None: @@ -1351,7 +1351,7 @@ class Mask2FormerPixelDecoder(nn.Module): for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]): lateral_conv = self.lateral_convolutions[idx] output_conv = self.output_convolutions[idx] - current_fpn = lateral_conv(feature.float()) + current_fpn = lateral_conv(feature) # Following FPN implementation, we use nearest upsampling here out = current_fpn + nn.functional.interpolate( diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 8a9a950cbd9..4f9500bd127 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -1290,15 +1290,15 @@ class MaskFormerSinePositionEmbedding(nn.Module): def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: if mask is None: mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) + not_mask = (~mask).to(x.dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 06d64709540..0b15e28d641 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -1179,8 +1179,8 @@ class OneFormerPixelDecoderEncoderOnly(nn.Module): reference_points_list = [] for lvl, (height, width) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid( - torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device), - torch.linspace(0.5, width - 0.5, width, 
dtype=torch.float32, device=device), + torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), ) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width) @@ -1352,14 +1352,14 @@ class OneFormerPixelDecoder(nn.Module): self.lateral_convs = lateral_convs[::-1] self.output_convs = output_convs[::-1] - def get_valid_ratio(self, mask): + def get_valid_ratio(self, mask, dtype=torch.float32): """Get the valid ratio of all feature maps.""" _, height, width = mask.shape valid_height = torch.sum(~mask[:, :, 0], 1) valid_width = torch.sum(~mask[:, 0, :], 1) - valid_ratio_heigth = valid_height.float() / height - valid_ratio_width = valid_width.float() / width + valid_ratio_heigth = valid_height.to(dtype) / height + valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) return valid_ratio @@ -1380,9 +1380,8 @@ class OneFormerPixelDecoder(nn.Module): sources = [] position_embeddings_list = [] for level, source in enumerate(features[::-1][: self.num_feature_levels]): - feats = source.float() - sources.append(self.input_projections[level](feats)) - position_embeddings_list.append(self.position_embedding(feats)) + sources.append(self.input_projections[level](source)) + position_embeddings_list.append(self.position_embedding(source)) masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in sources] @@ -1407,8 +1406,7 @@ class OneFormerPixelDecoder(nn.Module): lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) - valid_ratios = valid_ratios.float() + valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder # Also provide spatial_shapes, level_start_index and valid_ratios @@ -1445,7 +1443,6 @@ class OneFormerPixelDecoder(nn.Module): # append `out` with extra FPN levels # Reverse feature maps into top-down order (from low to high resolution) for idx, feats in enumerate(features[: self.num_fpn_levels][::-1]): - feats = feats.float() lateral_conv = self.lateral_convs[idx] output_conv = self.output_convs[idx] cur_fpn = lateral_conv(feats) @@ -2396,15 +2393,15 @@ class OneFormerSinePositionEmbedding(nn.Module): def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: if mask is None: mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) + not_mask = (~mask).to(x.dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / 
self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -2744,7 +2741,7 @@ class OneFormerTaskModel(nn.Module): ) def forward(self, inputs: Tensor) -> Tensor: - task_tokens = self.task_mlp(inputs.float()) + task_tokens = self.task_mlp(inputs) return task_tokens @@ -2980,7 +2977,7 @@ class OneFormerModel(OneFormerPreTrainedModel): multi_scale_features = pixel_level_module_output.decoder_features mask_features = pixel_level_module_output.decoder_last_feature - task_token = self.task_encoder(task_inputs) + task_token = self.task_encoder(task_inputs.to(self.dtype)) if self.is_training: text_queries = self.text_mapper(text_inputs) diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py index 898f2199922..c05901a9dcc 100644 --- a/tests/models/mask2former/test_modeling_mask2former.py +++ b/tests/models/mask2former/test_modeling_mask2former.py @@ -21,7 +21,14 @@ import numpy as np from tests.test_modeling_common import floats_tensor from transformers import Mask2FormerConfig, is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -420,6 +427,20 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) + @require_torch_gpu + def test_inference_fp16(self): + model = ( + Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints) + .to(torch_device, dtype=torch.float16) + .eval() + ) + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16) + + with torch.no_grad(): + _ = model(**inputs) + def test_with_segmentation_maps_and_loss(self): model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() image_processor = self.default_image_processor diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py index c37127cea73..32d5e03b031 100644 --- a/tests/models/maskformer/test_modeling_maskformer.py +++ b/tests/models/maskformer/test_modeling_maskformer.py @@ -22,7 +22,14 @@ import numpy as np from tests.test_modeling_common import floats_tensor from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -509,6 +516,20 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) + @require_torch_gpu + def test_inference_fp16(self): + model = ( + MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff") + .to(torch_device, dtype=torch.float16) + .eval() + ) + 
image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16) + + with torch.no_grad(): + _ = model(**inputs) + def test_with_segmentation_maps_and_loss(self): model = ( MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco") diff --git a/tests/models/oneformer/test_modeling_oneformer.py b/tests/models/oneformer/test_modeling_oneformer.py index ef4a45021ac..222f1ce66b0 100644 --- a/tests/models/oneformer/test_modeling_oneformer.py +++ b/tests/models/oneformer/test_modeling_oneformer.py @@ -22,7 +22,14 @@ import numpy as np from tests.test_modeling_common import floats_tensor from transformers import OneFormerConfig, is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -533,6 +540,20 @@ class OneFormerModelIntegrationTest(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) + @require_torch_gpu + def test_inference_fp16(self): + model = ( + OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints) + .to(torch_device, dtype=torch.float16) + .eval() + ) + processor = self.default_processor + image = prepare_img() + inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device, dtype=torch.float16) + + with torch.no_grad(): + _ = model(**inputs) + def test_with_segmentation_maps_and_loss(self): dummy_model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints) processor = self.default_processor
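
For reference, below is a minimal usage sketch of the fp16 inference path this patch enables, mirroring the new test_inference_fp16 integration tests. The checkpoint name and image URL are illustrative assumptions, not taken from the patch; any public Mask2Former checkpoint on a CUDA device should behave the same way, and the same pattern applies to MaskFormerForInstanceSegmentation and OneFormerForUniversalSegmentation.

    import requests
    import torch
    from PIL import Image

    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    # Illustrative checkpoint and image (not part of the patch); substitute your own.
    checkpoint = "facebook/mask2former-swin-tiny-coco-instance"
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    device = "cuda"  # fp16 inference assumes a CUDA device, as in the tests above

    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = (
        Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)
        .to(device, dtype=torch.float16)  # cast the weights to half precision
        .eval()
    )

    image = Image.open(requests.get(url, stream=True).raw)
    # Cast the pixel inputs to fp16 as well, mirroring the new integration tests
    inputs = image_processor(image, return_tensors="pt").to(device, dtype=torch.float16)

    with torch.no_grad():
        outputs = model(**inputs)

    # With this patch the forward pass stays in half precision end to end,
    # so the class logits are expected to come back as torch.float16.
    print(outputs.class_queries_logits.dtype)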