mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
adapting transfo tokenizer to transposed inputs
This commit is contained in:
parent
884ca81d87
commit
e8fe6b7140
@ -356,7 +356,10 @@ class LMOrderedIterator(object):
|
|||||||
data = self.data[beg_idx:end_idx]
|
data = self.data[beg_idx:end_idx]
|
||||||
target = self.data[i+1:i+1+seq_len]
|
target = self.data[i+1:i+1+seq_len]
|
||||||
|
|
||||||
return data, target, seq_len
|
data_out = data.transpose(0, 1).contiguous().to(self.device)
|
||||||
|
target_out = target.transpose(0, 1).contiguous().to(self.device)
|
||||||
|
|
||||||
|
return data_out, target_out, seq_len
|
||||||
|
|
||||||
def get_fixlen_iter(self, start=0):
|
def get_fixlen_iter(self, start=0):
|
||||||
for i in range(start, self.data.size(0) - 1, self.bptt):
|
for i in range(start, self.data.size(0) - 1, self.bptt):
|
||||||
@ -440,10 +443,10 @@ class LMShuffledIterator(object):
|
|||||||
if not valid_batch:
|
if not valid_batch:
|
||||||
return
|
return
|
||||||
|
|
||||||
data = data.to(self.device)
|
data_out = data.transpose(0, 1).contiguous().to(self.device)
|
||||||
target = target.to(self.device)
|
target_out = target.transpose(0, 1).contiguous().to(self.device)
|
||||||
|
|
||||||
yield data, target, self.bptt
|
yield data_out, target_out, self.bptt
|
||||||
|
|
||||||
n_retain = min(data.size(0), self.ext_len)
|
n_retain = min(data.size(0), self.ext_len)
|
||||||
if n_retain > 0:
|
if n_retain > 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user