mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Minor changes
This commit is contained in:
parent
7469d03b1c
commit
821de121e8
@ -72,7 +72,6 @@ class Discriminator(torch.nn.Module):
|
||||
def train_custom(self):
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
pass
|
||||
self.classifier_head.train()
|
||||
|
||||
def avg_representation(self, x):
|
||||
@ -122,7 +121,7 @@ def collate_fn(data):
|
||||
padded_sequences = torch.zeros(
|
||||
len(sequences),
|
||||
max(lengths)
|
||||
).long() # padding index 0
|
||||
).long() # padding value = 0
|
||||
|
||||
for i, seq in enumerate(sequences):
|
||||
end = lengths[i]
|
||||
|
Loading…
Reference in New Issue
Block a user