Fix SwinLayer / DonutSwinLayer / ClapAudioLayer attention mask device (#31295)

Fix DonutSwinLayer attention mask device
This commit is contained in:
Alex Gorodnitskiy 2024-06-06 21:52:14 +01:00 committed by GitHub
parent b6c9f47fd6
commit 3b4d3d09fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 15 deletions

View File

@ -593,10 +593,10 @@ class ClapAudioLayer(nn.Module):
self.shift_size = 0
self.window_size = min(input_resolution)
def get_attn_mask(self, height, width, dtype):
def get_attn_mask(self, height, width, dtype, device):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
@ -661,9 +661,9 @@ class ClapAudioLayer(nn.Module):
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
attn_mask = self.get_attn_mask(
height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
)
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions

View File

@ -565,10 +565,10 @@ class DonutSwinLayer(nn.Module):
self.shift_size = 0
self.window_size = min(input_resolution)
def get_attn_mask(self, height, width, dtype):
def get_attn_mask(self, height, width, dtype, device):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
@ -633,9 +633,9 @@ class DonutSwinLayer(nn.Module):
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
attn_mask = self.get_attn_mask(
height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
)
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions

View File

@ -642,10 +642,10 @@ class SwinLayer(nn.Module):
self.shift_size = 0
self.window_size = min(input_resolution)
def get_attn_mask(self, height, width, dtype):
def get_attn_mask(self, height, width, dtype, device):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
@ -710,9 +710,9 @@ class SwinLayer(nn.Module):
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
attn_mask = self.get_attn_mask(
height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
)
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions