mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Fix SwinLayer / DonutSwinLayer / ClapAudioLayer attention mask device (#31295)
Fix DonutSwinLayer attention mask device
This commit is contained in:
parent
b6c9f47fd6
commit
3b4d3d09fd
@ -593,10 +593,10 @@ class ClapAudioLayer(nn.Module):
|
||||
self.shift_size = 0
|
||||
self.window_size = min(input_resolution)
|
||||
|
||||
def get_attn_mask(self, height, width, dtype):
|
||||
def get_attn_mask(self, height, width, dtype, device):
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
|
||||
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
|
||||
height_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
@ -661,9 +661,9 @@ class ClapAudioLayer(nn.Module):
|
||||
# partition windows
|
||||
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
|
||||
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
|
||||
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(hidden_states_windows.device)
|
||||
attn_mask = self.get_attn_mask(
|
||||
height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
|
||||
)
|
||||
|
||||
attention_outputs = self.attention(
|
||||
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
|
||||
|
@ -565,10 +565,10 @@ class DonutSwinLayer(nn.Module):
|
||||
self.shift_size = 0
|
||||
self.window_size = min(input_resolution)
|
||||
|
||||
def get_attn_mask(self, height, width, dtype):
|
||||
def get_attn_mask(self, height, width, dtype, device):
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
|
||||
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
|
||||
height_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
@ -633,9 +633,9 @@ class DonutSwinLayer(nn.Module):
|
||||
# partition windows
|
||||
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
|
||||
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
|
||||
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(hidden_states_windows.device)
|
||||
attn_mask = self.get_attn_mask(
|
||||
height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
|
||||
)
|
||||
|
||||
attention_outputs = self.attention(
|
||||
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
|
||||
|
@ -642,10 +642,10 @@ class SwinLayer(nn.Module):
|
||||
self.shift_size = 0
|
||||
self.window_size = min(input_resolution)
|
||||
|
||||
def get_attn_mask(self, height, width, dtype):
|
||||
def get_attn_mask(self, height, width, dtype, device):
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
|
||||
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
|
||||
height_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
@ -710,9 +710,9 @@ class SwinLayer(nn.Module):
|
||||
# partition windows
|
||||
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
|
||||
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
|
||||
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(hidden_states_windows.device)
|
||||
attn_mask = self.get_attn_mask(
|
||||
height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
|
||||
)
|
||||
|
||||
attention_outputs = self.attention(
|
||||
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
|
||||
|
Loading…
Reference in New Issue
Block a user