
This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm picking the longest widely accepted line length, 119 characters. This is also Thomas' preference, because it allows for explicit variable names, which make the code easier to understand.
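For illustration, the same setting can be exercised through Black's Python API (a minimal sketch; the commit itself just ran the CLI command shown above, and the example line below is hypothetical):

import black

# Hypothetical one-liner: about 96 characters, so Black's default 88-column
# limit would split it, but it fits under the 119-character limit chosen here.
src = "result = some_function(argument_one, argument_two, argument_three, argument_four, argument_five)\n"

formatted = black.format_str(src, mode=black.Mode(line_length=119))
print(formatted)  # unchanged: the line already fits within 119 characters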
20 lines · 679 B · Python
import torch


class ClassificationHead(torch.nn.Module):
    """Classification Head for transformer encoders"""

    def __init__(self, class_size, embed_size):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits
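As a quick usage sketch: the head is a single linear layer mapping a hidden state of shape (batch, embed_size) to unnormalized logits of shape (batch, class_size). The batch size, embedding size, and class count below are illustrative, not taken from the file.

import torch

# Illustrative dimensions (not from the original file).
head = ClassificationHead(class_size=5, embed_size=768)
hidden = torch.randn(2, 768)  # e.g. a pooled transformer hidden state
logits = head(hidden)         # shape: (2, 5), unnormalized class scores
probs = torch.softmax(logits, dim=-1)  # apply softmax outside the module if needed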