Fixed num_channels!=3 normalization training (#20630)
* Fixed num_channels!=3 normalization training
* empty commit to trigger CI
* Empty-Commit for CircleCI
* Empty-Commit
* Empty Commit try-3: https://discuss.circleci.com/t/github-code-checkout-suddenly-failing/31558
* Empty commit to trigger CI

Co-authored-by: Lay Jain <layjain@basil.csail.mit.edu>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
commit 44caf4f6f4
parent 865da84abb
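The diff below touches the label construction in VideoMAEForPreTraining: when config.num_channels != 3, the 3-channel ImageNet mean/std cannot be used to unnormalize the input, so the norm_pix_loss path now keeps the normalized pixel_values as-is, and the other branch rejects non-RGB inputs outright. A minimal sketch (not part of the commit) of the single-channel pretraining setup this enables; all config values, the mask, and the tensor shapes are illustrative:

# Minimal sketch of single-channel VideoMAE pretraining, the case this fix targets.
# Hyperparameters and shapes below are illustrative, not taken from the commit.
import torch
from transformers import VideoMAEConfig, VideoMAEForPreTraining

config = VideoMAEConfig(
    num_channels=1,       # non-RGB input
    image_size=64,
    patch_size=16,
    num_frames=8,
    tubelet_size=2,
    norm_pix_loss=True,   # labels are built from per-patch-normalized frames
)
model = VideoMAEForPreTraining(config)

pixel_values = torch.randn(
    1, config.num_frames, config.num_channels, config.image_size, config.image_size
)
num_patches_per_frame = (config.image_size // config.patch_size) ** 2
seq_length = (config.num_frames // config.tubelet_size) * num_patches_per_frame
bool_masked_pos = torch.rand(1, seq_length) > 0.5  # random tube mask

outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
print(outputs.loss)  # reconstruction loss over the masked tubelets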
@@ -819,11 +819,15 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
         loss = None
         with torch.no_grad():
             # calculate the labels to be predicted
-            # first, unnormalize the frames
-            device = pixel_values.device
-            mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None]
-            std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None]
-            frames = pixel_values * std + mean  # in [0, 1]
+            if self.config.num_channels != 3:
+                # Can't unnormalize with default means/stds
+                frames = pixel_values
+            else:
+                # first, unnormalize the frames
+                device = pixel_values.device
+                mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None]
+                std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None]
+                frames = pixel_values * std + mean  # in [0, 1]

             batch_size, time, num_channels, height, width = frames.shape
             tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
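In the hunk above, the channel statistics are indexed with [None, None, :, None, None] so they broadcast over (batch_size, num_frames, num_channels, height, width); with three channels this recovers frames in roughly [0, 1], while any other channel count cannot be matched against the 3-element mean/std. A small standalone illustration; the tensor shapes are made up, only the statistics match what transformers ships as IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD:

# Standalone illustration of the unnormalization step; shapes are illustrative.
import torch

IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]  # ImageNet RGB statistics
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

# (batch, time, channels, height, width), already normalized in a real pipeline
pixel_values = torch.randn(2, 8, 3, 224, 224)
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN)[None, None, :, None, None]  # shape (1, 1, 3, 1, 1)
std = torch.as_tensor(IMAGENET_DEFAULT_STD)[None, None, :, None, None]

frames = pixel_values * std + mean  # approximately back in [0, 1] for RGB inputs
print(frames.shape)  # torch.Size([2, 8, 3, 224, 224])

# With a channel count other than 3 the 3-element statistics cannot be applied,
# which is why the patch falls back to frames = pixel_values for num_channels != 3.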
@@ -859,6 +863,10 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
                     tubelet_size * patch_size * patch_size * num_channels,
                 )
             else:
+                if self.config.num_channels != 3:
+                    raise ValueError(
+                        "Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to False."
+                    )
                 # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                 frames = frames.view(
                     batch_size,
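The second hunk adds a guard on the branch where labels are kept in pixel space, which requires unnormalizing with the RGB statistics; judging from the hunk context this is the norm_pix_loss=False branch, so a non-RGB configuration now fails loudly instead of producing mis-scaled targets. A hedged sketch of triggering that guard; the config values are illustrative:

# Sketch of the error path added above; config values are illustrative,
# and the affected branch is inferred from the hunk context.
import torch
from transformers import VideoMAEConfig, VideoMAEForPreTraining

config = VideoMAEConfig(
    num_channels=1, image_size=64, patch_size=16, num_frames=8, tubelet_size=2,
    norm_pix_loss=False,
)
model = VideoMAEForPreTraining(config)

pixel_values = torch.randn(1, 8, 1, 64, 64)
seq_length = (8 // 2) * (64 // 16) ** 2  # tubelets * patches per frame
bool_masked_pos = torch.rand(1, seq_length) > 0.5

try:
    model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
except ValueError as err:
    print(err)  # prints the message added in the hunk above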