Mirror of https://github.com/huggingface/transformers.git
Fix SwinLayer / DonutSwinLayer / ClapAudioLayer attention mask device (#31295)
Fix DonutSwinLayer attention mask device
This commit is contained in:
parent
b6c9f47fd6
commit
3b4d3d09fd
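The change is the same in all three modules: get_attn_mask now takes the target device and builds the SW-MSA mask there directly, so the forward pass no longer has to move the mask onto the device of the windowed hidden states afterwards. A minimal standalone sketch of the before/after pattern (toy shapes and a stub get_attn_mask, not the actual model code):

import torch

# Toy sketch of the device-placement pattern changed in this commit.
# The real methods go on to build the full shifted-window attention mask;
# only the allocation that matters for device placement is shown here.
def get_attn_mask(height, width, dtype, device=None):
    # Allocate the image mask directly on the requested device.
    img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
    return img_mask

hidden_states_windows = torch.randn(4, 49, 96)  # would live on a CUDA device in practice

# Before: allocate on the default device, then copy over if needed.
attn_mask = get_attn_mask(14, 14, dtype=hidden_states_windows.dtype)
if attn_mask is not None:
    attn_mask = attn_mask.to(hidden_states_windows.device)

# After: allocate on the right device from the start, with no extra copy.
attn_mask = get_attn_mask(
    14, 14, dtype=hidden_states_windows.dtype, device=hidden_states_windows.device
)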
@@ -593,10 +593,10 @@ class ClapAudioLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -661,9 +661,9 @@ class ClapAudioLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
@@ -565,10 +565,10 @@ class DonutSwinLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -633,9 +633,9 @@ class DonutSwinLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
@@ -642,10 +642,10 @@ class SwinLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -710,9 +710,9 @@ class SwinLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions