Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
MaskFormer, Mask2Former - reduce memory load (#25741)

Allocate the result array ahead of time and accumulate it channel by channel, instead of materializing the full broadcasted mask/pixel product.
parent 0daeeb40a1
commit ce2d4bc6a1
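The diff below swaps a broadcast-multiply-then-sum reduction for a result tensor that is allocated up front and accumulated one channel at a time. The following is a minimal standalone sketch of that idea, not the model code itself: the helper names (broadcast_sum, preallocated_sum), the tensor sizes, and the tolerances are assumptions made for illustration.

import torch


def broadcast_sum(mask_embeddings, pixel_embeddings):
    # Old approach: the product materializes a (batch, queries, channels, height, width)
    # intermediate before .sum(2) collapses the channel axis.
    return (mask_embeddings.unsqueeze(-1).unsqueeze(-1) * pixel_embeddings.unsqueeze(1)).sum(2)


def preallocated_sum(mask_embeddings, pixel_embeddings):
    # New approach: allocate the (batch, queries, height, width) result ahead of time and
    # accumulate per channel; equivalent to einsum('bqc, bchw -> bqhw').
    batch_size, num_queries, num_channels = mask_embeddings.shape
    _, _, height, width = pixel_embeddings.shape
    out = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
    for c in range(num_channels):
        out += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
    return out


# Equivalence check on small, made-up shapes.
mask = torch.randn(2, 100, 256)      # (batch_size, num_queries, num_channels)
pixel = torch.randn(2, 256, 32, 32)  # (batch_size, num_channels, height, width)
assert torch.allclose(broadcast_sum(mask, pixel), preallocated_sum(mask, pixel), atol=1e-4)
assert torch.allclose(torch.einsum("bqc,bchw->bqhw", mask, pixel), preallocated_sum(mask, pixel), atol=1e-4)

The loop trades a little extra compute for never holding the channel-expanded intermediate in memory, and the explicit indexing keeps the code jit friendly, as the in-diff comments note.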
@@ -2011,13 +2011,12 @@ class Mask2FormerMaskPredictor(nn.Module):
     def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
 
-        # Sum up over the channels
-        # (batch_size, num_queries, num_channels, 1, 1)
-        mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-        # (batch_size, 1, num_channels, height, width)
-        pixel_embeddings = pixel_embeddings.unsqueeze(1)
-        # (batch_size, num_queries, height, width)
-        outputs_mask = (mask_embeddings * pixel_embeddings).sum(2)
+        # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+        batch_size, num_queries, num_channels = mask_embeddings.shape
+        _, _, height, width = pixel_embeddings.shape
+        outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+        for c in range(num_channels):
+            outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
 
         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
@@ -1789,13 +1789,15 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             class_queries_logits = classes[-1]
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
-            # sum up over the channels for each embedding
-            # (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
-            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-            # (1, batch_size, 1, num_channels, height, width)
-            pixel_embeddings = pixel_embeddings.unsqueeze(0).unsqueeze(2)
-            # (num_embeddings, batch_size, num_queries, height, width)
-            binaries_masks = (mask_embeddings * pixel_embeddings).sum(dim=3)
+
+            # Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
+            num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
+            _, _, height, width = pixel_embeddings.shape
+            binaries_masks = torch.zeros(
+                (num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
+            )
+            for c in range(num_channels):
+                binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
 
             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
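In this auxiliary-loss path the reduction runs over the whole stack of decoder outputs, so the preallocated result gains a leading num_embeddings axis and the pixel embeddings pick up an extra None in the broadcast indexing. A minimal standalone sketch of that 'lbqc, bchw -> lbqhw' case, with assumed shapes rather than real model dimensions:

import torch

# Assumed sizes: decoder layers (l), batch (b), queries (q), channels (c), spatial (h, w).
num_embeddings, batch_size, num_queries, num_channels, height, width = 6, 2, 100, 256, 32, 32
mask_embeddings = torch.randn(num_embeddings, batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Preallocate the 5-D result and accumulate per channel, mirroring the loop in the diff above.
binaries_masks = torch.zeros(
    (num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
)
for c in range(num_channels):
    binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]

# Matches the einsum formulation named in the comment.
reference = torch.einsum("lbqc,bchw->lbqhw", mask_embeddings, pixel_embeddings)
assert torch.allclose(binaries_masks, reference, atol=1e-4)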
@@ -1811,12 +1813,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
-            # sum up over the channels
-            # (batch_size, num_queries, num_channels, 1, 1)
-            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-            # (batch_size, 1, num_channels, height, width)
-            pixel_embeddings = pixel_embeddings.unsqueeze(1)
-            # (batch_size, num_queries, height, width)
-            masks_queries_logits = (mask_embeddings * pixel_embeddings).sum(dim=2)
+
+            # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+            batch_size, num_queries, num_channels = mask_embeddings.shape
+            _, _, height, width = pixel_embeddings.shape
+            masks_queries_logits = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+            for c in range(num_channels):
+                masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
 
         return class_queries_logits, masks_queries_logits, auxiliary_logits
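What the preallocation buys is the disappearance of the channel-expanded intermediate: the broadcasted product carries a num_channels axis that the sum immediately collapses, while the loop only ever holds the final-sized result plus one single-channel product. A rough back-of-the-envelope comparison with assumed shapes (no measured numbers):

# Assumed shapes, for illustration only.
batch_size, num_queries, num_channels, height, width = 2, 100, 256, 128, 128
bytes_per_float32 = 4

# Old path: (batch_size, num_queries, num_channels, height, width) intermediate before the sum.
broadcast_intermediate = batch_size * num_queries * num_channels * height * width * bytes_per_float32

# New path: the (batch_size, num_queries, height, width) result allocated ahead of time;
# each loop iteration adds one temporary of the same size for the single-channel product.
preallocated_result = batch_size * num_queries * height * width * bytes_per_float32

print(f"broadcast intermediate: {broadcast_intermediate / 2**30:.2f} GiB")  # roughly num_channels times larger
print(f"preallocated result:    {preallocated_result / 2**20:.2f} MiB")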