fix zoedepth initialization error under deepspeed zero3 (#35011)

fix zoe bug in deepspeed zero3
This commit is contained in:
Qizhi Chen 2024-12-20 19:42:40 +08:00 committed by GitHub
parent c3a43594b7
commit 4567ee8057
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -417,7 +417,7 @@ class LogBinomialSoftmax(nn.Module):
self.k = n_classes
self.act = act
self.register_buffer("k_idx", torch.arange(0, n_classes).view(1, -1, 1, 1), persistent=False)
self.register_buffer("k_minus_1", torch.Tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False)
self.register_buffer("k_minus_1", torch.tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False)
def forward(self, probabilities, temperature=1.0, eps=1e-4):
"""Compute the log binomial distribution for probabilities.