mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fixed typos (issue 27919) (#27920)
* fixed typos (issue 27919) * Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
e5079b0b2a
commit
e660424717
@ -61,8 +61,8 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class ImageDistilTrainer(Trainer):
|
||||
def __init__(self, *args, teacher_model=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
|
||||
super().__init__(model=student_model, *args, **kwargs)
|
||||
self.teacher = teacher_model
|
||||
self.student = student_model
|
||||
self.loss_function = nn.KLDivLoss(reduction="batchmean")
|
||||
@ -164,7 +164,7 @@ trainer = ImageDistilTrainer(
|
||||
train_dataset=processed_datasets["train"],
|
||||
eval_dataset=processed_datasets["validation"],
|
||||
data_collator=data_collator,
|
||||
tokenizer=teacher_extractor,
|
||||
tokenizer=teacher_processor,
|
||||
compute_metrics=compute_metrics,
|
||||
temperature=5,
|
||||
lambda_param=0.5
|
||||
|
Loading…
Reference in New Issue
Block a user