Add mask2former fp16 support (#25093)
* Add mask2former fp16 support
* Clear consistency/quality issues
* Fix consistency/quality (2)
* Add integration test for mask2former (fp16 case)
* Fix code quality
* Add integration test for maskformer (fp16 case)
* Add integration test for oneformer (fp16 case)
* Remove slow decorator from fp16 tests
* Fix lint
* Remove usage of full inference and value checks for fp16
* Temporarily comment slow for {mask, mask2, one}former
* Add fp16 support to oneformer
* Revert "Temporarily comment slow for {mask, mask2, one}former"
This reverts commit e5371edabd.
* Remove dtype conversion noop
This commit is contained in:
  parent 5ee9693a1c
  commit 080a97119c
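What the change amounts to in practice: position embeddings, reference points, valid ratios, and the FPN/projection inputs previously forced float32 (via hard-coded dtype arguments and .float() calls); they now follow the dtype of the incoming activations, so a model cast to float16 runs end to end in half precision. A minimal usage sketch of what this enables — the checkpoint name and image URL below are illustrative assumptions, not taken from this commit:

import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

# Checkpoint and image are illustrative, not from the commit.
checkpoint = "facebook/mask2former-swin-tiny-coco-instance"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = (
    Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)
    .to("cuda", dtype=torch.float16)
    .eval()
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(image, return_tensors="pt").to("cuda", dtype=torch.float16)

with torch.no_grad():
    outputs = model(**inputs)  # the whole forward pass now runs in float16
print(outputs.class_queries_logits.dtype)  # expected: torch.float16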
src/transformers/models/mask2former/modeling_mask2former.py

@@ -864,15 +864,15 @@ class Mask2FormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

         pos_x = x_embed[:, :, :, None] / dim_t
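This pattern is the core of the fix: cumsum over a bool mask with dtype=torch.float32 pinned the position embedding to fp32, which then mixed dtypes when added to fp16 features. Casting not_mask to x.dtype first makes every intermediate inherit the input's precision. A standalone illustration (not part of the commit; fp16 cumsum may require a CUDA device on older PyTorch builds):

import torch

x = torch.randn(1, 8, 4, 4, dtype=torch.float16)
mask = torch.zeros(1, 4, 4, dtype=torch.bool)

old = (~mask).cumsum(1, dtype=torch.float32)  # always fp32, mismatches x
new = (~mask).to(x.dtype).cumsum(1)           # fp16, matches x
print(old.dtype, new.dtype)  # torch.float32 torch.float16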
@@ -1104,8 +1104,8 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
         reference_points_list = []
         for lvl, (height, width) in enumerate(spatial_shapes):
             ref_y, ref_x = torch.meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
                 indexing="ij",
             )
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
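Same idea for the deformable-attention reference points: the linspace grids now adopt valid_ratios.dtype, so the normalization by valid_ratios stays in a single dtype. A hypothetical standalone version (shapes invented for the demo, level index fixed to 0):

import torch

valid_ratios = torch.ones(2, 1, 2, dtype=torch.float16)  # hypothetical (batch, num_levels, 2)
height, width = 4, 6
device = valid_ratios.device

ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
    torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
    indexing="ij",
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, 0, 1] * height)
print(ref_y.dtype)  # torch.float16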
@@ -1267,14 +1267,14 @@ class Mask2FormerPixelDecoder(nn.Module):
         self.lateral_convolutions = lateral_convs[::-1]
         self.output_convolutions = output_convs[::-1]

-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""

         _, height, width = mask.shape
         valid_height = torch.sum(~mask[:, :, 0], 1)
         valid_width = torch.sum(~mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
+        valid_ratio_heigth = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
         valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
         return valid_ratio

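The signature change keeps the method backward compatible: dtype defaults to torch.float32 (the old behavior), and fp16 callers pass the dtype of the flattened embeddings. A self-contained sketch of the new contract (a simplified rewrite, not the commit's code):

import torch

def get_valid_ratio(mask, dtype=torch.float32):
    """Fraction of each padded feature map occupied by real (non-padding) pixels."""
    _, height, width = mask.shape
    valid_height = torch.sum(~mask[:, :, 0], 1)
    valid_width = torch.sum(~mask[:, 0, :], 1)
    valid_ratio_height = valid_height.to(dtype) / height
    valid_ratio_width = valid_width.to(dtype) / width
    return torch.stack([valid_ratio_width, valid_ratio_height], -1)

mask = torch.zeros(2, 8, 8, dtype=torch.bool)        # no padding anywhere
print(get_valid_ratio(mask).dtype)                   # torch.float32 (default)
print(get_valid_ratio(mask, torch.float16).dtype)    # torch.float16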
@@ -1295,8 +1295,8 @@ class Mask2FormerPixelDecoder(nn.Module):
         input_embeds = []
         position_embeddings = []
         for level, x in enumerate(features[::-1][: self.num_feature_levels]):
-            input_embeds.append(self.input_projections[level](x.float()))
-            position_embeddings.append(self.position_embedding(x.float()))
+            input_embeds.append(self.input_projections[level](x))
+            position_embeddings.append(self.position_embedding(x))

         masks = [
             torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds
@@ -1313,7 +1313,7 @@ class Mask2FormerPixelDecoder(nn.Module):
         level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)

         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(mask) for mask in masks], 1)
+        valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)

         # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
         if encoder_outputs is None:
@@ -1351,7 +1351,7 @@ class Mask2FormerPixelDecoder(nn.Module):
         for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]):
             lateral_conv = self.lateral_convolutions[idx]
             output_conv = self.output_convolutions[idx]
-            current_fpn = lateral_conv(feature.float())
+            current_fpn = lateral_conv(feature)

             # Following FPN implementation, we use nearest upsampling here
             out = current_fpn + nn.functional.interpolate(
src/transformers/models/maskformer/modeling_maskformer.py

@@ -1290,15 +1290,15 @@ class MaskFormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

         pos_x = x_embed[:, :, :, None] / dim_t
src/transformers/models/oneformer/modeling_oneformer.py

@@ -1179,8 +1179,8 @@ class OneFormerPixelDecoderEncoderOnly(nn.Module):
         reference_points_list = []
         for lvl, (height, width) in enumerate(spatial_shapes):
             ref_y, ref_x = torch.meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
             )
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width)
@@ -1352,14 +1352,14 @@ class OneFormerPixelDecoder(nn.Module):
         self.lateral_convs = lateral_convs[::-1]
         self.output_convs = output_convs[::-1]

-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""

         _, height, width = mask.shape
         valid_height = torch.sum(~mask[:, :, 0], 1)
         valid_width = torch.sum(~mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
+        valid_ratio_heigth = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
         valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
         return valid_ratio

@@ -1380,9 +1380,8 @@ class OneFormerPixelDecoder(nn.Module):
         sources = []
         position_embeddings_list = []
         for level, source in enumerate(features[::-1][: self.num_feature_levels]):
-            feats = source.float()
-            sources.append(self.input_projections[level](feats))
-            position_embeddings_list.append(self.position_embedding(feats))
+            sources.append(self.input_projections[level](source))
+            position_embeddings_list.append(self.position_embedding(source))

         masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in sources]

@@ -1407,8 +1406,7 @@ class OneFormerPixelDecoder(nn.Module):
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
         spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
-        valid_ratios = valid_ratios.float()
+        valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)

         # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
         # Also provide spatial_shapes, level_start_index and valid_ratios
@@ -1445,7 +1443,6 @@ class OneFormerPixelDecoder(nn.Module):
         # append `out` with extra FPN levels
         # Reverse feature maps into top-down order (from low to high resolution)
         for idx, feats in enumerate(features[: self.num_fpn_levels][::-1]):
-            feats = feats.float()
             lateral_conv = self.lateral_convs[idx]
             output_conv = self.output_convs[idx]
             cur_fpn = lateral_conv(feats)
@@ -2396,15 +2393,15 @@ class OneFormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

         pos_x = x_embed[:, :, :, None] / dim_t
@@ -2744,7 +2741,7 @@ class OneFormerTaskModel(nn.Module):
         )

     def forward(self, inputs: Tensor) -> Tensor:
-        task_tokens = self.task_mlp(inputs.float())
+        task_tokens = self.task_mlp(inputs)
         return task_tokens


@@ -2980,7 +2977,7 @@ class OneFormerModel(OneFormerPreTrainedModel):
         multi_scale_features = pixel_level_module_output.decoder_features
         mask_features = pixel_level_module_output.decoder_last_feature

-        task_token = self.task_encoder(task_inputs)
+        task_token = self.task_encoder(task_inputs.to(self.dtype))

         if self.is_training:
             text_queries = self.text_mapper(text_inputs)
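Note on the last two hunks: OneFormer's task_inputs arrive from the processor as integer token ids, so they must be cast to a floating dtype before the task MLP. The old code hard-coded .float() inside OneFormerTaskModel; the cast now happens once at the call site with task_inputs.to(self.dtype), which follows whatever precision the model was cast to (PreTrainedModel exposes .dtype as the dtype of its first floating-point parameter).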
tests/models/mask2former/test_modeling_mask2former.py

@@ -21,7 +21,14 @@ import numpy as np

 from tests.test_modeling_common import floats_tensor
 from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -420,6 +427,20 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
         image_processor = self.default_image_processor
tests/models/maskformer/test_modeling_maskformer.py

@@ -22,7 +22,14 @@ import numpy as np

 from tests.test_modeling_common import floats_tensor
 from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -509,6 +516,20 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         model = (
             MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
tests/models/oneformer/test_modeling_oneformer.py

@@ -22,7 +22,14 @@ import numpy as np

 from tests.test_modeling_common import floats_tensor
 from transformers import OneFormerConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -533,6 +540,20 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        processor = self.default_processor
+        image = prepare_img()
+        inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         dummy_model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
         processor = self.default_processor
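The three new test_inference_fp16 tests only run a forward pass with no logit comparisons, so they guard against dtype-mismatch crashes rather than numerical drift, and @require_torch_gpu skips them on CPU-only runners. Assuming a standard transformers checkout, one of them can be run with, for example:

    pytest tests/models/mask2former/test_modeling_mask2former.py -k test_inference_fp16

(prefix RUN_SLOW=1 if the integration test class is gated behind @slow in your checkout).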