mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
parent
1da782cb28
commit
c164064eef
@ -380,21 +380,19 @@ class Distiller:
|
|||||||
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
||||||
"""
|
"""
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
s_logits, s_hidden_states = self.student(
|
student_outputs = self.student(
|
||||||
input_ids=input_ids, attention_mask=attention_mask
|
input_ids=input_ids, attention_mask=attention_mask
|
||||||
) # (bs, seq_length, voc_size)
|
) # (bs, seq_length, voc_size)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
t_logits, t_hidden_states = self.teacher(
|
teacher_outputs = self.teacher(
|
||||||
input_ids=input_ids, attention_mask=attention_mask
|
input_ids=input_ids, attention_mask=attention_mask
|
||||||
) # (bs, seq_length, voc_size)
|
) # (bs, seq_length, voc_size)
|
||||||
else:
|
else:
|
||||||
s_logits, _, s_hidden_states = self.student(
|
student_outputs = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||||
input_ids=input_ids, attention_mask=None
|
|
||||||
) # (bs, seq_length, voc_size)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
t_logits, _, t_hidden_states = self.teacher(
|
teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||||
input_ids=input_ids, attention_mask=None
|
s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
|
||||||
) # (bs, seq_length, voc_size)
|
t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
|
||||||
assert s_logits.size() == t_logits.size()
|
assert s_logits.size() == t_logits.size()
|
||||||
|
|
||||||
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||||
|
Loading…
Reference in New Issue
Block a user