Fix Gemma2 dtype issue when storing weights in float16 precision (#35398)

Authored by Nerogar on 2025-02-13 11:17:37 +01:00; committed by GitHub
parent 08ab1abff4
commit 9065cf0d92
2 changed files with 2 additions and 2 deletions


@@ -301,7 +301,7 @@ class Gemma2DecoderLayer(nn.Module):
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
-                min_dtype = torch.finfo(hidden_states.dtype).min
+                min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)


@@ -349,7 +349,7 @@ class Gemma2DecoderLayer(nn.Module):
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
-                min_dtype = torch.finfo(hidden_states.dtype).min
+                min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
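
For context, here is a minimal sketch (hypothetical, not part of the commit) of why the fill value should come from the attention mask's dtype rather than the hidden states' dtype: when the weights are stored in float16, the hidden states and the prepared 4D mask can end up with different floating-point dtypes, and torch.finfo(...).min differs by many orders of magnitude between float16 and float32. Depending on which way the dtypes differ, the old fill value is either not the true minimum of the mask's dtype or not representable in it at all. The shapes and sliding-window value below are illustrative only.

# Minimal sketch (not from the commit) of the sliding-window fill.
import torch

hidden_states = torch.zeros(1, 4, 8, dtype=torch.float16)       # weights stored in float16
attention_mask = torch.zeros(1, 1, 4, 16, dtype=torch.float32)  # 4D mask kept in float32

old_fill = torch.finfo(hidden_states.dtype).min   # -65504.0 (float16 minimum)
new_fill = torch.finfo(attention_mask.dtype).min  # ~-3.4028e+38 (float32 minimum)

sliding_window = 4  # illustrative value
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)

# With the fix, the value written into the mask matches the mask's own dtype,
# so positions outside the sliding window are filled with that dtype's true minimum.
attention_mask = torch.where(sliding_window_mask, new_fill, attention_mask)
print(old_fill, new_fill)

Reading finfo from attention_mask itself, as both hunks above do, keeps the fill value consistent with the tensor it is written into regardless of the precision the weights are stored in.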