[SeamlessM4T] fix copies with NLLB MoE int8 (#27018)

fix copies on newly merged model
This commit is contained in:
Arthur 2023-10-23 15:25:06 +02:00 committed by GitHub
parent 244a53e0f6
commit f9f27b0fc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1300,7 +1300,7 @@ class SeamlessM4TFeedForwardNetwork(nn.Module):
if (
isinstance(self.fc2.weight, torch.Tensor)
and hidden_states.dtype != self.fc2.weight.dtype
and self.fc2.weight.dtype != torch.int8
and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8)
):
hidden_states = hidden_states.to(self.fc2.weight.dtype)
hidden_states = self.fc2(hidden_states)