MaskFormer,Mask2former - reduce memory load (#25741)

Allocate result array ahead of time
amyeroberts 2023-08-29 18:49:15 +01:00 committed by GitHub
parent 0daeeb40a1
commit ce2d4bc6a1
2 changed files with 22 additions and 20 deletions

src/transformers/models/mask2former/modeling_mask2former.py

@@ -2011,13 +2011,12 @@ class Mask2FormerMaskPredictor(nn.Module):
     def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
 
-        # Sum up over the channels
-        # (batch_size, num_queries, num_channels, 1, 1)
-        mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-        # (batch_size, 1, num_channels, height, width)
-        pixel_embeddings = pixel_embeddings.unsqueeze(1)
-        # (batch_size, num_queries, height, width)
-        outputs_mask = (mask_embeddings * pixel_embeddings).sum(2)
+        # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+        batch_size, num_queries, num_channels = mask_embeddings.shape
+        _, _, height, width = pixel_embeddings.shape
+        outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+        for c in range(num_channels):
+            outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
 
         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
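The removed path broadcast the two tensors into a (batch_size, num_queries, num_channels, height, width) intermediate and only then reduced over the channel axis; the replacement preallocates the (batch_size, num_queries, height, width) result and accumulates one channel slice at a time, so the full channel-expanded product is never materialized. A minimal sketch (tensor sizes invented for illustration) checking that the channel loop matches the einsum named in the new comment:

```python
import torch

# Invented sizes, for illustration only
batch_size, num_queries, num_channels, height, width = 2, 10, 256, 32, 32
mask_embeddings = torch.randn(batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Reference result: the einsum the loop is said to replace
expected = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)

# Allocate the result ahead of time, then accumulate channel by channel
outputs_mask = torch.zeros((batch_size, num_queries, height, width))
for c in range(num_channels):
    # (batch_size, num_queries, 1, 1) * (batch_size, 1, height, width)
    outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]

torch.testing.assert_close(outputs_mask, expected)
```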

src/transformers/models/maskformer/modeling_maskformer.py

@@ -1789,13 +1789,15 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             class_queries_logits = classes[-1]
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
 
-            # sum up over the channels for each embedding
-            # (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
-            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-            # (1, batch_size, 1, num_channels, height, width)
-            pixel_embeddings = pixel_embeddings.unsqueeze(0).unsqueeze(2)
-            # (num_embeddings, batch_size, num_queries, height, width)
-            binaries_masks = (mask_embeddings * pixel_embeddings).sum(dim=3)
+            # Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
+            num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
+            _, _, height, width = pixel_embeddings.shape
+            binaries_masks = torch.zeros(
+                (num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
+            )
+            for c in range(num_channels):
+                binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
+
             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
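The auxiliary-loss branch applies the same rewrite with an extra leading num_embeddings axis, one set of masks per decoder layer. A companion sketch, again with invented sizes, for the 'lbqc, bchw -> lbqhw' variant:

```python
import torch

# Invented sizes, for illustration only
num_embeddings, batch_size, num_queries, num_channels = 3, 2, 10, 256
height, width = 16, 16
mask_embeddings = torch.randn(num_embeddings, batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Reference result: the einsum the loop is said to replace
expected = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)

binaries_masks = torch.zeros((num_embeddings, batch_size, num_queries, height, width))
for c in range(num_channels):
    # (l, b, q, 1, 1) * (1, b, 1, h, w) broadcasts to (l, b, q, h, w)
    binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]

torch.testing.assert_close(binaries_masks, expected)
```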
@@ -1811,12 +1813,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
-            # sum up over the channels
-            # (batch_size, num_queries, num_channels, 1, 1)
-            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-            # (batch_size, 1, num_channels, height, width)
-            pixel_embeddings = pixel_embeddings.unsqueeze(1)
-            # (batch_size, num_queries, height, width)
-            masks_queries_logits = (mask_embeddings * pixel_embeddings).sum(dim=2)
+            # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+            batch_size, num_queries, num_channels = mask_embeddings.shape
+            _, _, height, width = pixel_embeddings.shape
+            masks_queries_logits = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+            for c in range(num_channels):
+                masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
 
         return class_queries_logits, masks_queries_logits, auxiliary_logits
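For a sense of scale, the saving comes from never allocating the channel-expanded intermediate; the loop's accumulator and per-channel temporaries stay at the size of the final output. A back-of-envelope comparison under assumed (hypothetical) fp32 sizes:

```python
# Hypothetical sizes; actual values depend on the model config and inputs
batch_size, num_queries, num_channels, height, width = 2, 100, 256, 96, 96

broadcast_elems = batch_size * num_queries * num_channels * height * width  # old intermediate
loop_elems = batch_size * num_queries * height * width                      # preallocated result

print(f"broadcast intermediate: {broadcast_elems * 4 / 2**30:.1f} GiB")  # ~1.8 GiB
print(f"loop accumulator:       {loop_elems * 4 / 2**20:.1f} MiB")       # ~7.0 MiB
```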