mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Minor changes
This commit is contained in:
parent
7469d03b1c
commit
821de121e8
@ -72,7 +72,6 @@ class Discriminator(torch.nn.Module):
|
|||||||
def train_custom(self):
|
def train_custom(self):
|
||||||
for param in self.encoder.parameters():
|
for param in self.encoder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
pass
|
|
||||||
self.classifier_head.train()
|
self.classifier_head.train()
|
||||||
|
|
||||||
def avg_representation(self, x):
|
def avg_representation(self, x):
|
||||||
@ -122,7 +121,7 @@ def collate_fn(data):
|
|||||||
padded_sequences = torch.zeros(
|
padded_sequences = torch.zeros(
|
||||||
len(sequences),
|
len(sequences),
|
||||||
max(lengths)
|
max(lengths)
|
||||||
).long() # padding index 0
|
).long() # padding value = 0
|
||||||
|
|
||||||
for i, seq in enumerate(sequences):
|
for i, seq in enumerate(sequences):
|
||||||
end = lengths[i]
|
end = lengths[i]
|
||||||
|
Loading…
Reference in New Issue
Block a user