Actually the extra_id are from 0-99 and not from 1-100 (#5967)

a = tokenizer.encode("we got a <extra_id_99>", return_tensors='pt',add_special_tokens=True)
print(a)
>tensor([[   62,   530,     3,     9, 32000]])
a = tokenizer.encode("we got a <extra_id_100>", return_tensors='pt',add_special_tokens=True)
print(a)
>tensor([[   62,   530,     3,     9,     3,     2, 25666,   834,    23,    26,
           834,  2915,  3155]])
This commit is contained in:
Oren Amsalem 2020-07-30 13:13:29 +03:00 committed by GitHub
parent 3212b8850d
commit d24ea708d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -38,13 +38,13 @@ T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
In this setup spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens)
and the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens.
Each sentinel token represents a unique mask token for this sentence and should start with ``<extra_id_1>``, ``<extra_id_2>``, ... up to ``<extra_id_100>``. As a default 100 sentinel tokens are available in ``T5Tokenizer``.
Each sentinel token represents a unique mask token for this sentence and should start with ``<extra_id_0>``, ``<extra_id_1>``, ... up to ``<extra_id_99>``. As a default 100 sentinel tokens are available in ``T5Tokenizer``.
*E.g.* the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be processed as follows:
::
input_ids = tokenizer.encode('The <extra_id_1> walks in <extra_id_2> park', return_tensors='pt')
labels = tokenizer.encode('<extra_id_1> cute dog <extra_id_2> the <extra_id_3> </s>', return_tensors='pt')
input_ids = tokenizer.encode('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt')
labels = tokenizer.encode('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt')
# the forward function automatically creates the correct decoder_input_ids
model(input_ids=input_ids, labels=labels)