Update: ignore padding support for TransfoXL training when n_clusters==0 (#22457)

* Update: ignore padding support for TransfoXL training when n_clusters==0

* Update: transformer XL always pad

* Update: drop doc
Stefan Heng 2023-03-29 14:36:39 -04:00 committed by GitHub
parent 2194943a34
commit cd73b9a8c1


@@ -86,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         """
         Params:
             hidden :: [len*bsz x d_proj]
-            labels :: [len*bsz
+            labels :: [len*bsz]
         Return:
             if labels is None: out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary else: out ::
@@ -109,7 +109,11 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
             if labels is not None:
-                out = -nn.functional.log_softmax(logit, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
+                mask = labels != -100
+                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)
+                out[mask] = (
+                    -nn.functional.log_softmax(logit, dim=-1)[mask].gather(1, labels[mask].unsqueeze(1)).squeeze(1)
+                )
             else:
                 out = nn.functional.log_softmax(logit, dim=-1)
         else:
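
For context, the snippet below is a minimal, self-contained sketch of the padding-ignore pattern this change introduces, run outside the Transfo-XL module: positions whose label is -100 are excluded from the log-softmax gather and their loss entries are left at zero. The tensor sizes and example labels are made-up assumptions for illustration, not values from the model.

import torch
import torch.nn as nn

# Illustrative shapes only: 6 flattened positions, vocabulary of 10 tokens
len_times_bsz, n_tokens = 6, 10
logit = torch.randn(len_times_bsz, n_tokens)     # [len*bsz x n_tokens]
labels = torch.tensor([3, 7, -100, 1, -100, 9])  # -100 marks padded positions

mask = labels != -100
out = torch.zeros_like(labels, dtype=logit.dtype)  # padded entries stay 0
out[mask] = (
    -nn.functional.log_softmax(logit, dim=-1)[mask]
    .gather(1, labels[mask].unsqueeze(1))
    .squeeze(1)
)

# Average the negative log-likelihood over non-padded tokens only
loss = out[mask].mean()
print(out, loss)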