mirror of https://github.com/huggingface/transformers.git (synced 2025-07-23 22:38:58 +06:00)
adapting transfo tokenizer to transposed inputs
This commit is contained in:
parent 884ca81d87
commit e8fe6b7140
@@ -356,7 +356,10 @@ class LMOrderedIterator(object):
         data = self.data[beg_idx:end_idx]
         target = self.data[i+1:i+1+seq_len]

-        return data, target, seq_len
+        data_out = data.transpose(0, 1).contiguous().to(self.device)
+        target_out = target.transpose(0, 1).contiguous().to(self.device)
+
+        return data_out, target_out, seq_len

     def get_fixlen_iter(self, start=0):
         for i in range(start, self.data.size(0) - 1, self.bptt):
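The hunk above makes LMOrderedIterator return batch-first tensors. A minimal
sketch of the effect, assuming (as in the original Transformer-XL iterator)
that self.data is stored sequence-major with shape (n_step, bsz); the tensor
values below are illustrative, not from the patch:

import torch

seq_len, bsz = 5, 3
data = torch.arange(seq_len * bsz).view(seq_len, bsz)  # sequence-major: (seq_len, bsz)

# What the patched get_batch now returns: a batch-first tensor.
# .contiguous() copies the transposed view into row-major memory so
# downstream reshapes such as .view() work on it.
data_out = data.transpose(0, 1).contiguous()           # (bsz, seq_len)

print(data.shape)      # torch.Size([5, 3])
print(data_out.shape)  # torch.Size([3, 5])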
@@ -440,10 +443,10 @@ class LMShuffledIterator(object):
             if not valid_batch:
                 return

-            data = data.to(self.device)
-            target = target.to(self.device)
+            data_out = data.transpose(0, 1).contiguous().to(self.device)
+            target_out = target.transpose(0, 1).contiguous().to(self.device)

-            yield data, target, self.bptt
+            yield data_out, target_out, self.bptt

             n_retain = min(data.size(0), self.ext_len)
             if n_retain > 0: