fixed typos (issue 27919) (#27920)

* fixed typos (issue 27919)

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Anthony Susevski 2023-12-11 18:44:23 -05:00 committed by GitHub
parent e5079b0b2a
commit e660424717


@@ -61,8 +61,8 @@ import torch.nn.functional as F
 class ImageDistilTrainer(Trainer):
-    def __init__(self, *args, teacher_model=None, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
+        super().__init__(model=student_model, *args, **kwargs)
         self.teacher = teacher_model
         self.student = student_model
         self.loss_function = nn.KLDivLoss(reduction="batchmean")
@@ -164,7 +164,7 @@ trainer = ImageDistilTrainer(
     train_dataset=processed_datasets["train"],
     eval_dataset=processed_datasets["validation"],
     data_collator=data_collator,
-    tokenizer=teacher_extractor,
+    tokenizer=teacher_processor,
     compute_metrics=compute_metrics,
     temperature=5,
     lambda_param=0.5
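
For context, the corrected __init__ above is only part of the distillation trainer described on that doc page; the other part is a compute_loss that mixes a KL-divergence distillation term with the student's ordinary label loss. Below is a minimal sketch of how the fixed signature fits into the full class. It assumes the doc's usual imports (torch, torch.nn, torch.nn.functional, transformers.Trainer); the compute_loss body follows the standard soft-target recipe and is an approximation, not a verbatim copy of the page.

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer


class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
        # The fix: hand the student to Trainer as `model` and accept the
        # temperature/lambda_param hyperparameters instead of leaving them undefined.
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        # Assumption: keep the frozen teacher on the same device used for training.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        # Student forward pass with gradients, teacher forward pass without.
        student_output = self.student(**inputs)
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        # Soften both distributions with the temperature; KLDivLoss expects
        # log-probabilities for the input and probabilities for the target.
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Blend the distillation term with the student's own label loss.
        student_target_loss = student_output.loss
        loss = (1.0 - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

The second hunk's rename from teacher_extractor to teacher_processor presumably brings the keyword argument in line with the image-processor variable defined earlier on the page, which is the kind of mismatch the "fixed typos" title refers to.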