[ViTMAE] Fix docstrings and variable names (#17710)

* Fix docstrings and variable names

* Rename x to something better

* Improve messages

* Fix docstrings and add test for greyscale images

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
NielsRogge 2022-06-21 15:56:00 +02:00 committed by GitHub
parent 3fab17fce8
commit b681e12d59
4 changed files with 185 additions and 57 deletions


@@ -84,7 +84,7 @@ class TFViTMAEDecoderOutput(ModelOutput):
Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions.
Args:
logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
@@ -109,7 +109,7 @@ class TFViTMAEForPreTrainingOutput(ModelOutput):
Args:
loss (`tf.Tensor` of shape `(1,)`):
Pixel reconstruction loss.
logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
@@ -969,50 +969,110 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
def patchify(self, imgs):
def patchify(self, pixel_values):
"""
imgs: (batch_size, height, width, 3) x: (batch_size, num_patches, patch_size**2 *3)
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
Pixel values.
Returns:
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
"""
imgs = tf.cond(
tf.math.equal(shape_list(imgs)[1], 3), lambda: tf.transpose(imgs, perm=(0, 2, 3, 1)), lambda: imgs
patch_size, num_channels = self.config.patch_size, self.config.num_channels
# make sure channels are last
pixel_values = tf.cond(
tf.math.equal(shape_list(pixel_values)[1], num_channels),
lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
lambda: pixel_values,
)
p = self.vit.embeddings.patch_embeddings.patch_size[0]
tf.debugging.assert_equal(shape_list(imgs)[1], shape_list(imgs)[2])
tf.debugging.assert_equal(shape_list(imgs)[1] % p, 0)
# sanity checks
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
shape_list(pixel_values)[2],
message="Make sure the pixel values have a squared size",
)
tf.debugging.assert_equal(
shape_list(pixel_values)[1] % patch_size,
0,
message="Make sure the pixel values have a size that is divisible by the patch size",
)
tf.debugging.assert_equal(
shape_list(pixel_values)[3],
num_channels,
message=(
"Make sure the number of channels of the pixel values is equal to the one set in the configuration"
),
)
h = w = shape_list(imgs)[2] // p
x = tf.reshape(imgs, (shape_list(imgs)[0], h, p, w, p, 3))
x = tf.einsum("nhpwqc->nhwpqc", x)
x = tf.reshape(x, (shape_list(imgs)[0], h * w, p**2 * 3))
return x
# patchify
batch_size = shape_list(pixel_values)[0]
num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
patchified_pixel_values = tf.reshape(
pixel_values,
(batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
)
return patchified_pixel_values
def unpatchify(self, x):
def unpatchify(self, patchified_pixel_values):
"""
x: (batch_size, num_patches, patch_size**2 *3) imgs: (batch_size, height, width, 3)
"""
p = self.vit.embeddings.patch_embeddings.patch_size[0]
h = w = int(shape_list(x)[1] ** 0.5)
tf.debugging.assert_equal(h * w, shape_list(x)[1])
Args:
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
Patchified pixel values.
x = tf.reshape(x, (shape_list(x)[0], h, w, p, p, 3))
x = tf.einsum("nhwpqc->nhpwqc", x)
imgs = tf.reshape(x, (shape_list(x)[0], h * p, h * p, 3))
return imgs
Returns:
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
Pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
# sanity check
tf.debugging.assert_equal(
num_patches_one_direction * num_patches_one_direction,
shape_list(patchified_pixel_values)[1],
message="Make sure that the number of patches can be squared",
)
def forward_loss(self, imgs, pred, mask):
# unpatchify
batch_size = shape_list(patchified_pixel_values)[0]
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
)
return pixel_values
def forward_loss(self, pixel_values, pred, mask):
"""
imgs: [batch_size, height, width, 3] pred: [batch_size, num_patches, patch_size**2*3] mask: [N, L], 0 is keep,
1 is remove,
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
Pixel values.
pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
Predicted pixel values.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
Returns:
`tf.Tensor`: Pixel reconstruction loss.
"""
target = self.patchify(imgs)
target = self.patchify(pixel_values)
if self.config.norm_pix_loss:
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
loss = tf.reduce_mean(loss, axis=-1) # [N, L], mean loss per patch
loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch
loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches
return loss
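The reshape/einsum pattern used by the new patchify and unpatchify can be checked in isolation. Below is a minimal NumPy sketch (dummy sizes, channels-last input, not part of the commit) showing that the two transforms are exact inverses of each other:

import numpy as np

batch_size, image_size, patch_size, num_channels = 2, 8, 4, 3
pixel_values = np.random.rand(batch_size, image_size, image_size, num_channels)

# patchify: split height and width into (num_patches_one_direction, patch_size) and flatten each patch
n = image_size // patch_size  # patches per direction
patches = pixel_values.reshape(batch_size, n, patch_size, n, patch_size, num_channels)
patches = np.einsum("nhpwqc->nhwpqc", patches)
patches = patches.reshape(batch_size, n * n, patch_size**2 * num_channels)

# unpatchify: invert the einsum and the two reshapes
restored = patches.reshape(batch_size, n, n, patch_size, patch_size, num_channels)
restored = np.einsum("nhwpqc->nhpwqc", restored)
restored = restored.reshape(batch_size, n * patch_size, n * patch_size, num_channels)

assert np.allclose(pixel_values, restored)  # lossless round trip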


@@ -86,7 +86,7 @@ class ViTMAEDecoderOutput(ModelOutput):
Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
Args:
logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -111,7 +111,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
Args:
loss (`torch.FloatTensor` of shape `(1,)`):
Pixel reconstruction loss.
logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
@@ -868,37 +868,86 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def patchify(self, imgs):
def patchify(self, pixel_values):
"""
imgs: (N, 3, H, W) x: (N, L, patch_size**2 *3)
"""
p = self.vit.embeddings.patch_embeddings.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values.
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum("nchpwq->nhwpqc", x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
Returns:
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
# sanity checks
if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
if pixel_values.shape[1] != num_channels:
raise ValueError(
"Make sure the number of channels of the pixel values is equal to the one set in the configuration"
)
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3) imgs: (N, 3, H, W)
"""
p = self.vit.embeddings.patch_embeddings.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
# patchify
batch_size = pixel_values.shape[0]
num_patches_one_direction = pixel_values.shape[2] // patch_size
patchified_pixel_values = pixel_values.reshape(
batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
)
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
)
return patchified_pixel_values
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
def unpatchify(self, patchified_pixel_values):
"""
Args:
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
Patchified pixel values.
def forward_loss(self, imgs, pred, mask):
Returns:
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
Pixel values.
"""
imgs: [N, 3, H, W] pred: [N, L, p*p*3] mask: [N, L], 0 is keep, 1 is remove,
patch_size, num_channels = self.config.patch_size, self.config.num_channels
num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
# sanity check
if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
raise ValueError("Make sure that the number of patches can be squared")
# unpatchify
batch_size = patchified_pixel_values.shape[0]
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size,
num_patches_one_direction,
num_patches_one_direction,
patch_size,
patch_size,
num_channels,
)
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
pixel_values = patchified_pixel_values.reshape(
batch_size,
num_channels,
num_patches_one_direction * patch_size,
num_patches_one_direction * patch_size,
)
return pixel_values
def forward_loss(self, pixel_values, pred, mask):
"""
target = self.patchify(imgs)
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values.
pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
Predicted pixel values.
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
Returns:
`torch.FloatTensor`: Pixel reconstruction loss.
"""
target = self.patchify(pixel_values)
if self.config.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
@@ -958,8 +1007,8 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
ids_restore = outputs.ids_restore
mask = outputs.mask
decoder_outputs = self.decoder(latent, ids_restore) # [N, L, p*p*3]
logits = decoder_outputs.logits
decoder_outputs = self.decoder(latent, ids_restore)
logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
loss = self.forward_loss(pixel_values, logits, mask)
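forward_loss averages the per-patch mean squared error only over the masked (removed) patches. A short PyTorch sketch with dummy tensors (shapes chosen for illustration only, not taken from the commit):

import torch

batch_size, num_patches, patch_dim = 2, 4, 48  # patch_dim = patch_size**2 * num_channels
pred = torch.randn(batch_size, num_patches, patch_dim)
target = torch.randn(batch_size, num_patches, patch_dim)
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0]])  # 1 = masked (removed) patch

norm_pix_loss = True  # mirrors config.norm_pix_loss: normalize targets per patch
if norm_pix_loss:
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)
    target = (target - mean) / (var + 1.0e-6) ** 0.5

loss = (pred - target) ** 2
loss = loss.mean(dim=-1)  # (batch_size, num_patches), mean loss per patch
loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches only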


@@ -140,6 +140,15 @@ class TFViTMAEModelTester:
expected_num_channels = self.patch_size**2 * self.num_channels
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
# test greyscale images
config.num_channels = 1
model = TFViTMAEForPreTraining(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values, training=False)
expected_num_channels = self.patch_size**2
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, pixel_values, labels) = config_and_inputs


@@ -137,6 +137,16 @@ class ViTMAEModelTester:
expected_num_channels = self.patch_size**2 * self.num_channels
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
# test greyscale images
config.num_channels = 1
model = ViTMAEForPreTraining(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
expected_num_channels = self.patch_size**2
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
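The greyscale path covered by the new tests can also be exercised directly. A minimal sketch, assuming transformers and torch are installed and using small, illustrative config values (not from the commit):

import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

config = ViTMAEConfig(image_size=32, patch_size=8, num_channels=1)  # tiny greyscale setup
model = ViTMAEForPreTraining(config).eval()
pixel_values = torch.rand(2, 1, 32, 32)  # batch of single-channel images

with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.logits.shape)  # (batch_size, num_patches, patch_size**2 * 1) -> torch.Size([2, 16, 64])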