mirror of https://github.com/huggingface/transformers.git
commit c164064eef (parent 1da782cb28)
@@ -380,21 +380,19 @@ class Distiller:
             lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
         """
         if self.mlm:
-            s_logits, s_hidden_states = self.student(
+            student_outputs = self.student(
                 input_ids=input_ids, attention_mask=attention_mask
             )  # (bs, seq_length, voc_size)
             with torch.no_grad():
-                t_logits, t_hidden_states = self.teacher(
+                teacher_outputs = self.teacher(
                     input_ids=input_ids, attention_mask=attention_mask
                 )  # (bs, seq_length, voc_size)
         else:
-            s_logits, _, s_hidden_states = self.student(
-                input_ids=input_ids, attention_mask=None
-            )  # (bs, seq_length, voc_size)
+            student_outputs = self.student(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
             with torch.no_grad():
-                t_logits, _, t_hidden_states = self.teacher(
-                    input_ids=input_ids, attention_mask=None
-                )  # (bs, seq_length, voc_size)
+                teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
+        s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
+        t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
         assert s_logits.size() == t_logits.size()
 
         # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
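Context for the change above: transformers models return dict-like ModelOutput objects (rather than plain tuples) when return_dict is enabled, so the diff swaps positional unpacking for key access on "logits" and "hidden_states". Below is a minimal sketch of that access pattern, assuming a stock distilbert-base-uncased checkpoint; the model name and example sentence are illustrative and not taken from this commit.

import torch
from transformers import DistilBertForMaskedLM, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForMaskedLM.from_pretrained(
    "distilbert-base-uncased",
    output_hidden_states=True,  # required so outputs["hidden_states"] is populated
    return_dict=True,           # outputs behave like a dict (ModelOutput)
)

batch = tokenizer("Knowledge distillation [MASK] models.", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# Key-based access replaces positional tuple unpacking:
logits = outputs["logits"]                # (bs, seq_length, voc_size)
hidden_states = outputs["hidden_states"]  # tuple of (bs, seq_length, dim), one per layer
print(logits.shape, len(hidden_states))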
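The trailing comment in the hunk links to the reference knowledge-distillation loss in peterliht's knowledge-distillation-pytorch: a KL divergence between temperature-softened student and teacher distributions. For orientation, here is a self-contained sketch of that soft-target loss, assuming the (bs, seq_length, voc_size) logits shapes annotated in the diff; the function name and temperature value are hypothetical, and this is not the Distiller's exact loss code.

import torch
import torch.nn.functional as F

def soft_target_loss(s_logits, t_logits, temperature=2.0):
    # Flatten (bs, seq_length, voc_size) -> (bs * seq_length, voc_size) so
    # "batchmean" averages the KL divergence per token position.
    voc_size = s_logits.size(-1)
    s = s_logits.view(-1, voc_size)
    t = t_logits.view(-1, voc_size)
    loss = F.kl_div(
        F.log_softmax(s / temperature, dim=-1),
        F.softmax(t / temperature, dim=-1),
        reduction="batchmean",
    )
    # Scale by T**2 so gradient magnitudes stay comparable across temperatures,
    # as is standard practice in knowledge distillation (Hinton et al.).
    return loss * temperature ** 2

s = torch.randn(2, 8, 100)  # fake student logits: (bs=2, seq_length=8, voc_size=100)
t = torch.randn(2, 8, 100)  # fake teacher logits
print(soft_target_loss(s, t))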