Minor changes

This commit is contained in:
piero 2019-11-27 15:27:49 -08:00 committed by Julien Chaumond
parent 7469d03b1c
commit 821de121e8

View File

@ -72,7 +72,6 @@ class Discriminator(torch.nn.Module):
def train_custom(self):
for param in self.encoder.parameters():
param.requires_grad = False
pass
self.classifier_head.train()
def avg_representation(self, x):
@ -122,7 +121,7 @@ def collate_fn(data):
padded_sequences = torch.zeros(
len(sequences),
max(lengths)
).long() # padding index 0
).long() # padding value = 0
for i, seq in enumerate(sequences):
end = lengths[i]