mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-19 20:48:22 +06:00
Make EncodecModel.decode ONNX exportable (#29913)
* fix encodec onnx export for musicgen * simplification * fix quality * better style
This commit is contained in:
parent
b44df05bc0
commit
81642d2b51
@ -111,14 +111,27 @@ class EncodecConv1d(nn.Module):
|
|||||||
elif self.norm_type == "time_group_norm":
|
elif self.norm_type == "time_group_norm":
|
||||||
self.norm = nn.GroupNorm(1, out_channels)
|
self.norm = nn.GroupNorm(1, out_channels)
|
||||||
|
|
||||||
@staticmethod
|
kernel_size = self.conv.kernel_size[0]
|
||||||
|
stride = torch.tensor(self.conv.stride[0], dtype=torch.int64)
|
||||||
|
dilation = self.conv.dilation[0]
|
||||||
|
|
||||||
|
# Effective kernel size with dilations.
|
||||||
|
kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)
|
||||||
|
|
||||||
|
self.register_buffer("stride", stride, persistent=False)
|
||||||
|
self.register_buffer("kernel_size", kernel_size, persistent=False)
|
||||||
|
self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)
|
||||||
|
|
||||||
def _get_extra_padding_for_conv1d(
|
def _get_extra_padding_for_conv1d(
|
||||||
hidden_states: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
self,
|
||||||
) -> int:
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
"""See `pad_for_conv1d`."""
|
"""See `pad_for_conv1d`."""
|
||||||
length = hidden_states.shape[-1]
|
length = hidden_states.shape[-1]
|
||||||
n_frames = (length - kernel_size + padding_total) / stride + 1
|
n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
|
||||||
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
n_frames = torch.ceil(n_frames).to(torch.int64) - 1
|
||||||
|
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
|
||||||
|
|
||||||
return ideal_length - length
|
return ideal_length - length
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -141,20 +154,15 @@ class EncodecConv1d(nn.Module):
|
|||||||
return padded[..., :end]
|
return padded[..., :end]
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
kernel_size = self.conv.kernel_size[0]
|
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
|
||||||
stride = self.conv.stride[0]
|
|
||||||
dilation = self.conv.dilation[0]
|
|
||||||
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
|
|
||||||
padding_total = kernel_size - stride
|
|
||||||
extra_padding = self._get_extra_padding_for_conv1d(hidden_states, kernel_size, stride, padding_total)
|
|
||||||
|
|
||||||
if self.causal:
|
if self.causal:
|
||||||
# Left padding for causal
|
# Left padding for causal
|
||||||
hidden_states = self._pad1d(hidden_states, (padding_total, extra_padding), mode=self.pad_mode)
|
hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
|
||||||
else:
|
else:
|
||||||
# Asymmetric padding required for odd strides
|
# Asymmetric padding required for odd strides
|
||||||
padding_right = padding_total // 2
|
padding_right = self.padding_total // 2
|
||||||
padding_left = padding_total - padding_right
|
padding_left = self.padding_total - padding_right
|
||||||
hidden_states = self._pad1d(
|
hidden_states = self._pad1d(
|
||||||
hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode
|
hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user