Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
Fix Gemma2 dtype issue when storing weights in float16 precision (#35398)
Parent: 08ab1abff4
Commit: 9065cf0d92
@@ -301,7 +301,7 @@ class Gemma2DecoderLayer(nn.Module):
             # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
             # from the left, with an offset if we are beyond the sliding window
             else:
-                min_dtype = torch.finfo(hidden_states.dtype).min
+                min_dtype = torch.finfo(attention_mask.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
@@ -349,7 +349,7 @@ class Gemma2DecoderLayer(nn.Module):
             # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
             # from the left, with an offset if we are beyond the sliding window
             else:
-                min_dtype = torch.finfo(hidden_states.dtype).min
+                min_dtype = torch.finfo(attention_mask.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
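Why the one-argument change matters: the fill value is written into attention_mask, so it must be representable in attention_mask's own dtype. If the hidden states have been upcast (for example to float32) while the mask is still in the float16 storage dtype, torch.finfo(hidden_states.dtype).min overflows the float16 mask to -inf, which can turn fully masked rows into NaN after the softmax. Below is a minimal sketch of that failure mode; the tensor shapes and the float32/float16 split are illustrative assumptions, not taken from the issue itself.

import torch

# Assumed setup (not from #35398): hidden states upcast to float32,
# attention mask kept in the float16 storage dtype.
hidden_states = torch.zeros(1, 4, dtype=torch.float32)
attention_mask = torch.zeros(1, 1, 4, 4, dtype=torch.float16)
sliding_window_mask = torch.ones_like(attention_mask, dtype=torch.bool)

# Before the fix: fill value taken from hidden_states.dtype (float32 min, ~-3.4e38),
# which is not representable in float16 and overflows to -inf.
old_fill = torch.finfo(hidden_states.dtype).min
old_mask = torch.where(sliding_window_mask, old_fill, attention_mask)
print(old_mask.unique())   # tensor([-inf], dtype=torch.float16) -> NaN risk in softmax

# After the fix: fill value taken from attention_mask.dtype (float16 min, -65504),
# so the masked positions stay finite in the mask's own dtype.
new_fill = torch.finfo(attention_mask.dtype).min
new_mask = torch.where(sliding_window_mask, new_fill, attention_mask)
print(new_mask.unique())   # tensor([-65504.], dtype=torch.float16)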