mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Make GPT2 traceable in meta state (#28054)
* Put device in tensor constructor instead of `to()`
* Fix copy
This commit is contained in:
parent
e2b6df7971
commit
74cae670ce
@@ -185,7 +185,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

         if attention_mask is not None:
@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module):
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

         if attention_mask is not None:
Loading…
Reference in New Issue
Block a user