Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 03:01:07 +06:00
Remove my unhelpful comments :)
This commit is contained in:
parent fda2f62395
commit 01520d5412
@@ -621,9 +621,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             shift_logits = lm_logits[:, :-1].contiguous()
             shift_labels = lm_labels[:, 1:].contiguous()
 
-            # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
-            # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
-            # We just flatten the tokens out this way.
+            # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
@@ -720,9 +720,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             shift_logits = lm_logits[:, :-1].contiguous()
             shift_labels = lm_labels[:, 1:].contiguous()
 
-            # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
-            # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
-            # We just flatten the tokens out this way.
+            # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
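For context, the code both hunks touch is the standard next-token language-modeling loss: logits and labels are shifted by one position so that position t predicts token t+1, then flattened so CrossEntropyLoss sees [N, num_classes] logits against [N] targets. Below is a minimal, runnable sketch of that pattern; the variable names lm_logits and lm_labels mirror the diff, while the tensor sizes and the final print are illustrative assumptions, not part of the commit.

import torch
from torch.nn import CrossEntropyLoss

# Illustrative stand-in shapes (not from the commit): batch of 2,
# sequence of 5, vocabulary of 10.
batch_size, seq_len, vocab_size = 2, 5, 10
lm_logits = torch.randn(batch_size, seq_len, vocab_size)
lm_labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Shift so that tokens < n predict n: drop the last logit and the
# first label.
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()

# Flatten batch and sequence dimensions together; positions labeled -1
# are excluded from the loss via ignore_index.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
print(loss.item())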