Fix ORTTrainer failure on gpt2 fp16 training (#18017)
* Ensure value and attn weights have the same dtype
* Remove prints
* Modify decision transformers copied from gpt2
* Nit device

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent 2b09650885
commit 2844c5de10
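The change keeps the attention-weight scaling in the same dtype and on the same device as attn_weights, which is what fp16 training under ORTTrainer tripped over. Below is a minimal sketch of the patched pattern, not the repository code: the shapes are made up and head_dim is a hypothetical stand-in for value.size(-1).

import torch

# Stand-in for the attention scores computed from query/key; float16 as in fp16 training.
attn_weights = torch.randn(2, 4, 4, dtype=torch.float16)
head_dim = 8  # hypothetical stand-in for value.size(-1)

# Patched pattern: the divisor is a tensor that explicitly shares attn_weights'
# dtype and device, so the scaled weights stay float16.
attn_weights = attn_weights / torch.tensor(
    head_dim ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
assert attn_weights.dtype == torch.float16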
@@ -178,7 +178,9 @@ class DecisionTransformerGPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (value.size(-1) ** 0.5)
+            attn_weights = attn_weights / torch.tensor(
+                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            )
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -189,7 +189,9 @@ class GPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (value.size(-1) ** 0.5)
+            attn_weights = attn_weights / torch.tensor(
+                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            )
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
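As a quick sanity check (again a sketch, not part of the commit), the tensor-wrapped divisor is numerically equivalent to the previous Python-float division up to float16 rounding; the patch only pins the divisor's dtype and device. head_dim is again a hypothetical stand-in for value.size(-1).

import torch

attn_weights = torch.randn(2, 4, 4, dtype=torch.float16)
head_dim = 8  # hypothetical stand-in for value.size(-1)

old_style = attn_weights / (head_dim ** 0.5)
new_style = attn_weights / torch.tensor(
    head_dim ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
torch.testing.assert_close(old_style, new_style)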