Fix distiller.py (#12910)

* fix distiller

* fix style
chutaklee 2021-07-29 02:11:38 +08:00 committed by GitHub
parent 1da782cb28
commit c164064eef


@@ -380,21 +380,19 @@ class Distiller:
         lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
         """
         if self.mlm:
-            s_logits, s_hidden_states = self.student(
+            student_outputs = self.student(
                 input_ids=input_ids, attention_mask=attention_mask
             )  # (bs, seq_length, voc_size)
             with torch.no_grad():
-                t_logits, t_hidden_states = self.teacher(
+                teacher_outputs = self.teacher(
                     input_ids=input_ids, attention_mask=attention_mask
                 )  # (bs, seq_length, voc_size)
         else:
-            s_logits, _, s_hidden_states = self.student(
-                input_ids=input_ids, attention_mask=None
-            )  # (bs, seq_length, voc_size)
+            student_outputs = self.student(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
             with torch.no_grad():
-                t_logits, _, t_hidden_states = self.teacher(
-                    input_ids=input_ids, attention_mask=None
-                )  # (bs, seq_length, voc_size)
+                teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
+        s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
+        t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
         assert s_logits.size() == t_logits.size()
         # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
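For context on the change: the old tuple unpacking stops working once the student and teacher forwards return dict-like ModelOutput objects, so the patch fetches logits and hidden states by key instead. The snippet below is a minimal standalone sketch of that access pattern, assuming a recent transformers (v4.x) release and using DistilBertForMaskedLM purely as a stand-in for the student/teacher models; it is an illustration, not part of the patch.

# Minimal sketch of the output-by-key access pattern (assumption: transformers v4.x;
# DistilBertForMaskedLM is only a stand-in for the student/teacher models).
import torch
from transformers import DistilBertConfig, DistilBertForMaskedLM

# hidden states are only returned when the config asks for them
config = DistilBertConfig(output_hidden_states=True)
model = DistilBertForMaskedLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # (bs, seq_length)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

# ModelOutput is dict-like: index by key rather than unpacking a tuple
logits = outputs["logits"]                # (bs, seq_length, voc_size)
hidden_states = outputs["hidden_states"]  # tuple of (bs, seq_length, dim), one entry per layer
print(logits.shape, len(hidden_states))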