
* Reorganize example folder
* Continue reorganization
* Change requirements for tests
* Final cleanup
* Finish regroup with tests all passing
* Copyright
* Requirements and readme
* Make a full link for the documentation
* Address review comments
* Apply suggestions from code review
* Add symlink
* Reorg again
* Apply suggestions from code review
* Adapt title
* Update to new structure
* Remove test
* Update READMEs

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
import torch


class ClassificationHead(torch.nn.Module):
    """Classification head for transformer encoders.

    Maps a hidden state of size ``embed_size`` to ``class_size`` logits
    with a single linear layer.
    """

    def __init__(self, class_size, embed_size):
        super().__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # An earlier two-layer variant, kept for reference:
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = torch.nn.Linear(embed_size, class_size)
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # Two-layer variant (would require torch.nn.functional as F):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits
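A minimal usage sketch, not part of the original file: the sizes below are hypothetical, and mean-pooling over the sequence dimension is just one common way to get a single vector per example from encoder output before classifying.

import torch

# Hypothetical sizes for illustration: a 768-dim encoder and 5 target classes.
head = ClassificationHead(class_size=5, embed_size=768)

# Stand-in for encoder output of shape (batch, seq_len, embed_size).
hidden_states = torch.randn(2, 16, 768)

# Pool over the sequence dimension to get one vector per example.
pooled = hidden_states.mean(dim=1)  # (batch, embed_size)

logits = head(pooled)               # (batch, class_size)
print(logits.shape)                 # torch.Size([2, 5])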