Mirror of https://github.com/huggingface/transformers.git
change constant torch.tensor to torch.full (#20061)
parent 787620e2a2
commit 707b12a353
@@ -170,8 +170,8 @@ class DecisionTransformerGPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.tensor(
-                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )

         # Layer-wise attention scaling
@@ -185,7 +185,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
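The change is mechanical: a 0-dim constant tensor previously built with torch.tensor is now built with torch.full and an empty size, so the division still operates on a scalar tensor carrying the dtype and device of attn_weights. The same two hunks are applied to GPT2Attention below. A minimal standalone sketch of the equivalence (the shapes here are arbitrary and only illustrative, not taken from the model):

import torch

attn_weights = torch.randn(2, 4, 8, 8, dtype=torch.float16)
value = torch.randn(2, 4, 8, 16, dtype=torch.float16)

# Old form: 0-dim tensor built from a Python float.
old_scale = torch.tensor(value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device)
# New form: the same 0-dim tensor, written as torch.full with an empty shape.
new_scale = torch.full([], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device)

assert old_scale.shape == new_scale.shape == torch.Size([])
assert torch.equal(old_scale, new_scale)
assert torch.allclose(attn_weights / old_scale, attn_weights / new_scale)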
@@ -182,8 +182,8 @@ class GPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.tensor(
-                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )

         # Layer-wise attention scaling
@@ -197,7 +197,7 @@ class GPT2Attention(nn.Module):
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
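The mask_value hunks follow the same pattern; the existing comments explain why the constant must be a tensor matching attn_weights in dtype and device before being passed to torch.where. A small standalone sketch of that masking step (shapes are arbitrary, for illustration only):

import torch

attn_weights = torch.randn(1, 2, 4, 4)
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# Most negative finite value representable in the attention dtype.
mask_value = torch.finfo(attn_weights.dtype).min
# Wrap the Python float in a 0-dim tensor so dtype and device match attn_weights.
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)

# Positions outside the causal mask are filled with the large negative constant.
attn_weights = torch.where(causal_mask, attn_weights, mask_value)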