Correct segment ID for XLNet single sequence

This commit is contained in:
Lysandre 2020-01-21 11:33:45 -05:00
parent 073219b43f
commit 088fa7b759

View File

@@ -240,7 +240,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
cls_segment_id = [2]
if token_ids_1 is None:
-return len(token_ids_0 + sep + cls) * [0]
+return len(token_ids_0 + sep) * [0] + cls_segment_id
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
def save_vocabulary(self, save_directory):