adapting transfo tokenizer to transposed inputs

This commit is contained in:
thomwolf 2019-02-11 13:30:04 +01:00
parent 884ca81d87
commit e8fe6b7140

View File

@ -356,7 +356,10 @@ class LMOrderedIterator(object):
data = self.data[beg_idx:end_idx]
target = self.data[i+1:i+1+seq_len]
return data, target, seq_len
data_out = data.transpose(0, 1).contiguous().to(self.device)
target_out = target.transpose(0, 1).contiguous().to(self.device)
return data_out, target_out, seq_len
def get_fixlen_iter(self, start=0):
for i in range(start, self.data.size(0) - 1, self.bptt):
@ -440,10 +443,10 @@ class LMShuffledIterator(object):
if not valid_batch:
return
data = data.to(self.device)
target = target.to(self.device)
data_out = data.transpose(0, 1).contiguous().to(self.device)
target_out = target.transpose(0, 1).contiguous().to(self.device)
yield data, target, self.bptt
yield data_out, target_out, self.bptt
n_retain = min(data.size(0), self.ext_len)
if n_retain > 0: