Mirror of https://github.com/huggingface/transformers.git
Synced 2025-07-31 02:02:21 +06:00
Update: ignore padding support for TransfoXL training when n_clusters==0 (#22457)
* Update: ignore padding support for TransfoXL training when n_clusters==0
* Update: transformer XL always pad
* Update: drop doc
This commit is contained in:
parent 2194943a34
commit cd73b9a8c1
@@ -86,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         """
         Params:
             hidden :: [len*bsz x d_proj]
-            labels :: [len*bsz
+            labels :: [len*bsz]

         Return:
             if labels is None: out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary else: out ::
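For reference, the documented shapes can be read as follows: the hidden states are flattened over the sequence and batch dimensions before being scored, and the labels form a flat vector of the same length. A minimal sketch with assumed example sizes (seq_len, bsz, d_proj are illustrative, not values from the repository):

# Illustrative shapes only; all sizes are assumed example values.
import torch

seq_len, bsz, d_proj = 5, 2, 8
hidden = torch.randn(seq_len, bsz, d_proj).view(-1, d_proj)  # [len*bsz x d_proj]
labels = torch.randint(0, 100, (seq_len * bsz,))             # [len*bsz]
print(hidden.shape, labels.shape)  # torch.Size([10, 8]) torch.Size([10])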
@@ -109,7 +109,11 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
             if labels is not None:
-                out = -nn.functional.log_softmax(logit, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
+                mask = labels != -100
+                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)
+                out[mask] = (
+                    -nn.functional.log_softmax(logit, dim=-1)[mask].gather(1, labels[mask].unsqueeze(1)).squeeze(1)
+                )
             else:
                 out = nn.functional.log_softmax(logit, dim=-1)
         else:
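The new branch treats the label value -100 (the conventional PyTorch ignore index) as padding: those positions are excluded from the gather and are assigned a zero loss, so padded tokens no longer contribute to the negative log-likelihood. Below is a minimal standalone sketch of the same masking idea, with assumed example sizes; it is not the library code itself.

# Standalone illustration of the -100 masking; sizes and values are assumptions.
import torch
import torch.nn as nn

n_positions, n_tokens = 6, 10                    # assumed example sizes
logit = torch.randn(n_positions, n_tokens)       # [len*bsz x n_tokens]
labels = torch.tensor([3, 7, -100, 1, -100, 9])  # -100 marks padded positions

mask = labels != -100
out = torch.zeros_like(labels, dtype=logit.dtype)
out[mask] = (
    -nn.functional.log_softmax(logit, dim=-1)[mask]
    .gather(1, labels[mask].unsqueeze(1))
    .squeeze(1)
)
print(out)  # padded positions stay exactly 0.0 and drop out of the loss

Summing or averaging out over the non-masked positions then matches the loss that would be computed on the unpadded tokens alone.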