Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Add FP32 cast in ConvNext LayerNorm to prevent rounding errors with FP16 input (#18746)
* Adding cast to fp32 in convnext layernorm to prevent rounding errors in the case of fp16 input
* Trigger CI
This commit is contained in:
parent 532ca05079
commit d63bdf78d4
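Why the cast matters: when the mean/variance reduction of the channels_first branch runs entirely in fp16, the result drifts from the fp32 reference. A minimal sketch of that effect (the tensor shape and scale here are only illustrative, not from the PR):

import torch

# Illustrative activations in fp32 as the reference.
x = torch.randn(1, 96, 56, 56) * 10
u = x.mean(1, keepdim=True)
var_fp32 = (x - u).pow(2).mean(1, keepdim=True)

# The same reduction carried out in fp16 accumulates rounding error.
xh = x.half()
uh = xh.mean(1, keepdim=True)
var_fp16 = (xh - uh).pow(2).mean(1, keepdim=True)

print((var_fp32 - var_fp16.float()).abs().max())  # small but non-zero drift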
@@ -109,9 +109,12 @@ class ConvNextLayerNorm(nn.Module):
         if self.data_format == "channels_last":
             x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         elif self.data_format == "channels_first":
+            input_dtype = x.dtype
+            x = x.float()
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
+            x = x.to(dtype=input_dtype)
             x = self.weight[:, None, None] * x + self.bias[:, None, None]
         return x
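A minimal, self-contained sketch of the patched channels_first path, assuming the same weight, bias, and eps attributes as in the diff above; the class name and feature size are placeholders, not from the PR:

import torch
import torch.nn as nn

class ConvNextLayerNormSketch(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()  # run the mean/variance reduction in fp32
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = x.to(dtype=input_dtype)  # restore the caller's dtype before the affine step
        return self.weight[:, None, None] * x + self.bias[:, None, None]

# fp16 activations stay fp16 end to end, but the statistics are computed in fp32.
ln = ConvNextLayerNormSketch(96).half()
out = ln(torch.randn(2, 96, 56, 56, dtype=torch.float16))
print(out.dtype)  # torch.float16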