Mirror of https://github.com/huggingface/transformers.git

Commit d51f72d5de: adding shortcut to the ids of all the special tokens
Parent commit: 306af132d7
@@ -679,13 +679,16 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):

         Examples::

+            import torch
+            from pytorch_transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
+
             tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
             model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
             tokenizer.add_special_tokens({'cls_token': '[CLS]'})  # Add a [CLS] to the vocabulary (we should train it also!)
             model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings to the new vocabulary size (add a vector at the end)
             choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
             input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
-            mc_token_ids = torch.tensor([input_ids.size(-1), input_ids.size(-1)]).unsqueeze(0)  # Batch size 1
+            mc_token_ids = torch.tensor([input_ids.size(-1)])  # Batch size 1
             outputs = model(input_ids, mc_token_ids)
             lm_prediction_scores, mc_prediction_scores = outputs[:2]
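In the example, mc_token_ids is meant to point at the token whose hidden state feeds the multiple-choice head, i.e. the [CLS] appended to each choice. A minimal runnable sketch under that reading (not the docstring verbatim; it derives the [CLS] positions from the encoded choices rather than from input_ids.size(-1)):

    import torch
    from pytorch_transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
    tokenizer.add_special_tokens({'cls_token': '[CLS]'})   # new token, its embedding is untrained
    model.resize_token_embeddings(len(tokenizer))

    choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
    encoded = [tokenizer.encode(s) for s in choices]
    input_ids = torch.tensor(encoded).unsqueeze(0)                    # (1, 2, seq_len): batch size 1, 2 choices
    mc_token_ids = torch.tensor([[len(ids) - 1 for ids in encoded]])  # (1, 2): index of [CLS] in each choice

    outputs = model(input_ids, mc_token_ids=mc_token_ids)
    lm_prediction_scores, mc_prediction_scores = outputs[:2]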
@@ -128,8 +128,8 @@ class CommonTestCases:

            self.assertGreater(tokens[0], tokens[1])
            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[-2], tokens[-3])
-           self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
-           self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
+           self.assertEqual(tokens[0], tokenizer.eos_token_id)
+           self.assertEqual(tokens[-2], tokenizer.pad_token_id)


        def test_required_methods_tokenizer(self):
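The updated assertions lean on the equivalence the new shortcuts are meant to guarantee: each *_token_id property should return the same id as an explicit convert_tokens_to_ids lookup. A small sanity check along those lines (the '<EOS>' and '<PAD>' strings are illustrative, not from the commit):

    from pytorch_transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.add_special_tokens({'eos_token': '<EOS>', 'pad_token': '<PAD>'})

    # The shortcut should agree with the explicit token -> id lookup it replaces.
    assert tokenizer.eos_token_id == tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids(tokenizer.pad_token)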
@@ -155,6 +155,62 @@ class PreTrainedTokenizer(object):
     def additional_special_tokens(self, value):
         self._additional_special_tokens = value

+    @property
+    def bos_token_id(self):
+        """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
+        if self._bos_token is None:
+            logger.error("Using bos_token, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._bos_token)
+
+    @property
+    def eos_token_id(self):
+        """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
+        if self._eos_token is None:
+            logger.error("Using eos_token, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._eos_token)
+
+    @property
+    def unk_token_id(self):
+        """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
+        if self._unk_token is None:
+            logger.error("Using unk_token, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._unk_token)
+
+    @property
+    def sep_token_id(self):
+        """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
+        if self._sep_token is None:
+            logger.error("Using sep_token, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._sep_token)
+
+    @property
+    def pad_token_id(self):
+        """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
+        if self._pad_token is None:
+            logger.error("Using pad_token, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._pad_token)
+
+    @property
+    def cls_token_id(self):
+        """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+        if self._cls_token is None:
+            logger.error("Using cls_token, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._cls_token)
+
+    @property
+    def mask_token_id(self):
+        """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+        if self._mask_token is None:
+            logger.error("Using mask_token, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._mask_token)
+
+    @property
+    def additional_special_tokens_ids(self):
+        """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
+        if self._additional_special_tokens is None:
+            logger.error("Using additional_special_tokens, but it is not set yet.")
+        return self.convert_tokens_to_ids(self._additional_special_tokens)
+
     def __init__(self, max_len=None, **kwargs):
         self._bos_token = None
         self._eos_token = None
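With these properties in place, calling code can read special-token ids straight off the tokenizer. One possible use, padding a batch by hand with pad_token_id instead of convert_tokens_to_ids(tokenizer.pad_token) (the sentences and the '[PAD]' string are illustrative):

    import torch
    from pytorch_transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    sentences = ["Hello, my dog is cute", "Hello there"]
    encoded = [tokenizer.encode(s) for s in sentences]
    max_len = max(len(ids) for ids in encoded)

    # pad_token_id is the new shortcut for tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    padded = [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in encoded]
    input_ids = torch.tensor(padded)
    attention_mask = (input_ids != tokenizer.pad_token_id).long()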