Add FP32 cast in ConvNext LayerNorm to prevent rounding errors with FP16 input (#18746)

* Add a cast to FP32 in ConvNext LayerNorm to prevent rounding errors with FP16 input

* Trigger CI
Jim Briggs 2022-09-16 13:42:57 +01:00 committed by GitHub
parent 532ca05079
commit d63bdf78d4


@@ -109,9 +109,12 @@ class ConvNextLayerNorm(nn.Module):
         if self.data_format == "channels_last":
             x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         elif self.data_format == "channels_first":
+            input_dtype = x.dtype
+            x = x.float()
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
+            x = x.to(dtype=input_dtype)
             x = self.weight[:, None, None] * x + self.bias[:, None, None]
         return x