diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index e9274f1e54d..44d34c833b4 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -109,9 +109,12 @@ class ConvNextLayerNorm(nn.Module): if self.data_format == "channels_last": x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x