From e8fe6b7140a3b48c72a5ef528099d2518856124d Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Mon, 11 Feb 2019 13:30:04 +0100
Subject: [PATCH] adapting transfo tokenizer to transposed inputs

---
 pytorch_pretrained_bert/tokenization_transfo_xl.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py
index 585a8159239..3f74726f6fa 100644
--- a/pytorch_pretrained_bert/tokenization_transfo_xl.py
+++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py
@@ -356,7 +356,10 @@ class LMOrderedIterator(object):
         data = self.data[beg_idx:end_idx]
         target = self.data[i+1:i+1+seq_len]
 
-        return data, target, seq_len
+        data_out = data.transpose(0, 1).contiguous().to(self.device)
+        target_out = target.transpose(0, 1).contiguous().to(self.device)
+
+        return data_out, target_out, seq_len
 
     def get_fixlen_iter(self, start=0):
         for i in range(start, self.data.size(0) - 1, self.bptt):
@@ -440,10 +443,10 @@ class LMShuffledIterator(object):
             if not valid_batch:
                 return
 
-            data = data.to(self.device)
-            target = target.to(self.device)
+            data_out = data.transpose(0, 1).contiguous().to(self.device)
+            target_out = target.transpose(0, 1).contiguous().to(self.device)
 
-            yield data, target, self.bptt
+            yield data_out, target_out, self.bptt
 
             n_retain = min(data.size(0), self.ext_len)
             if n_retain > 0:
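
Note for readers (not part of the commit): a minimal sketch of what the added
transpose does, assuming the sequence-first [seq_len, batch_size] layout these
iterators use internally, presumably flipped here to match the batch-first
convention used elsewhere in this repo. The shapes and variable names below
are illustrative, not taken from the patch.

    import torch

    # The iterators slice the token stream into sequence-first chunks of
    # shape [seq_len, batch_size]; the patch flips each chunk to batch-first
    # [batch_size, seq_len] before yielding it.
    seq_len, batch_size = 16, 4
    chunk = torch.arange(seq_len * batch_size).view(seq_len, batch_size)

    out = chunk.transpose(0, 1)  # batch-first view, shares storage with chunk
    assert out.shape == (batch_size, seq_len)

    # transpose() returns a non-contiguous view, which is why the patched
    # code chains .contiguous() before .to(self.device).
    assert not out.is_contiguous()
    out = out.contiguous()
    assert out.is_contiguous()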