transformers/tests/test_modeling_xlnet.py
Michael Benayoun 0fe17f375a
FX tracing improvement (#14321)
* Change the way tracing happens, enabling dynamic axes out of the box

* Update the tests and modeling xlnet

* Add the non-recording of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors).

* Comments and making tracing work for gpt-j and xlnet

* Refactor things related to num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone autowrap_function feature usage for later

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
2022-02-07 22:25:33 +01:00

1072 lines
33 KiB
Python

# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
from transformers import XLNetConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
import torch
from transformers import (
XLNetForMultipleChoice,
XLNetForQuestionAnswering,
XLNetForQuestionAnsweringSimple,
XLNetForSequenceClassification,
XLNetForTokenClassification,
XLNetLMHeadModel,
XLNetModel,
)
from transformers.models.xlnet.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_LIST
class XLNetModelTester:
    """Builds tiny XLNet configs and random inputs, and runs the shape checks used by `XLNetModelTest`.

    `parent` is the running `unittest.TestCase`; every `create_and_check_*` method reports through
    its `assert*` methods. Each `create_and_check_*` method takes the 12-tuple returned by
    `prepare_config_and_inputs`, unpacked as positional arguments.
    """

    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        mem_len=10,
        clamp_len=-1,
        reuse_len=15,
        is_training=True,
        use_labels=True,
        vocab_size=99,
        cutoffs=(10, 50, 80),
        hidden_size=32,
        num_attention_heads=4,
        d_inner=128,
        num_hidden_layers=5,
        type_sequence_label_size=2,
        untie_r=True,
        bi_data=False,
        same_length=False,
        initializer_range=0.05,
        seed=1,
        type_vocab_size=2,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=5,
        num_choices=4,
    ):
        # Honour every argument (instead of hard-coding the defaults) so callers can resize the model.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.mem_len = mem_len
        self.clamp_len = clamp_len
        self.reuse_len = reuse_len
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        # Store a fresh list per instance so mutating it can never leak into other testers.
        self.cutoffs = list(cutoffs)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.d_inner = d_inner
        self.num_hidden_layers = num_hidden_layers
        self.type_sequence_label_size = type_sequence_label_size
        self.untie_r = untie_r
        self.bi_data = bi_data
        self.same_length = same_length
        self.initializer_range = initializer_range
        self.seed = seed
        self.type_vocab_size = type_vocab_size
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.num_choices = num_choices

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping,
        segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels)`.

        `input_ids_q`, `perm_mask` and `target_mapping` are one token longer than the other inputs and
        are set up so that only the last token is predicted and hidden from the previous tokens.
        All label tensors are `None` when `use_labels` is False.
        """
        input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
        input_mask = random_attention_mask([self.batch_size, self.seq_length])

        input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
        perm_mask = torch.zeros(
            self.batch_size,
            self.seq_length + 1,
            self.seq_length + 1,
            dtype=torch.float,
            device=torch_device,
        )
        perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
        target_mapping = torch.zeros(
            self.batch_size,
            1,
            self.seq_length + 1,
            dtype=torch.float,
            device=torch_device,
        )
        target_mapping[:, 0, -1] = 1.0  # predict last token

        sequence_labels = None
        lm_labels = None
        is_impossible_labels = None
        token_labels = None
        if self.use_labels:
            lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            is_impossible_labels = ids_tensor([self.batch_size], 2).float()
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        config = self.get_config()

        return (
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
            token_labels,
        )

    def get_config(self):
        """Build an `XLNetConfig` from this tester's hyper-parameters."""
        return XLNetConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            n_head=self.num_attention_heads,
            d_inner=self.d_inner,
            n_layer=self.num_hidden_layers,
            untie_r=self.untie_r,
            mem_len=self.mem_len,
            clamp_len=self.clamp_len,
            same_length=self.same_length,
            reuse_len=self.reuse_len,
            bi_data=self.bi_data,
            initializer_range=self.initializer_range,
            num_labels=self.type_sequence_label_size,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
        )

    def set_seed(self):
        """Seed both Python's `random` and torch with `self.seed`."""
        random.seed(self.seed)
        torch.manual_seed(self.seed)

    def create_and_check_xlnet_base_model(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Run `XLNetModel` with the optional inputs and check hidden-state and mems shapes."""
        model = XLNetModel(config)
        model.to(torch_device)
        model.eval()

        # Exercise the optional inputs; only the final plain call's outputs are checked below.
        result = model(input_ids_1, input_mask=input_mask)
        result = model(input_ids_1, attention_mask=input_mask)
        result = model(input_ids_1, token_type_ids=segment_ids)
        result = model(input_ids_1)

        # Note: this mutates `config`, which is freshly built per test by `prepare_config_and_inputs`.
        config.mem_len = 0
        model = XLNetModel(config)
        model.to(torch_device)
        model.eval()
        base_model_output = model(input_ids_1)
        self.parent.assertEqual(len(base_model_output), 2)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertListEqual(
            [mem.shape for mem in result.mems],
            [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
        )

    def create_and_check_use_mems_train(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Train-mode forward passes over mini-batches must return a loss and no mems."""
        model = XLNetForSequenceClassification(config)
        model.to(torch_device)
        model.train()

        train_size = input_ids_1.shape[0]
        train_batch_size = 4
        # Consecutive, non-overlapping mini-batches; the last one may be smaller but is never empty.
        for start in range(0, train_size, train_batch_size):
            input_ids = input_ids_1[start : start + train_batch_size]
            labels = sequence_labels[start : start + train_batch_size]
            outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
            self.parent.assertIsNone(outputs.mems)
            self.parent.assertIsNotNone(outputs.loss)

    def create_and_check_xlnet_model_use_mems(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Check that decoding one extra token with cached mems matches a full recomputation."""
        model = XLNetModel(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        causal_mask = torch.ones(
            input_ids_1.shape[0],
            input_ids_1.shape[1],
            input_ids_1.shape[1],
            dtype=torch.float,
            device=torch_device,
        )
        causal_mask = torch.triu(causal_mask, diagonal=0)
        outputs_cache = model(input_ids_1, use_mems=True, perm_mask=causal_mask)
        outputs_no_cache = model(input_ids_1, use_mems=False, perm_mask=causal_mask)
        outputs_conf = model(input_ids_1)

        # `use_mems=False` drops exactly one output entry (the mems).
        self.parent.assertEqual(len(outputs_cache), len(outputs_conf))
        self.parent.assertEqual(len(outputs_cache), len(outputs_no_cache) + 1)

        output, mems = outputs_cache.to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append to next input_ids
        next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)

        # causal mask over the extended sequence
        causal_mask = torch.ones(
            input_ids_1.shape[0],
            input_ids_1.shape[1] + 1,
            input_ids_1.shape[1] + 1,
            dtype=torch.float,
            device=torch_device,
        )
        causal_mask = torch.triu(causal_mask, diagonal=0)
        single_mask = torch.ones(input_ids_1.shape[0], 1, 1, dtype=torch.float, device=torch_device)

        # second forward pass
        output_from_no_past = model(next_input_ids, perm_mask=causal_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, mems=mems, perm_mask=single_mask)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_xlnet_base_model_with_att_output(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """With `target_mapping`, each layer's attentions is a 2-tuple (content stream, query stream)."""
        model = XLNetModel(config)
        model.to(torch_device)
        model.eval()

        attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)["attentions"]

        self.parent.assertEqual(len(attentions), config.n_layer)
        self.parent.assertIsInstance(attentions[0], tuple)
        self.parent.assertEqual(len(attentions[0]), 2)

        h_attentions, g_attentions = attentions[0]
        # Content stream: full self-attention over the (mem-less) input sequence.
        self.parent.assertEqual(
            h_attentions.shape,
            (self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length),
        )
        # Query stream: only pin batch, head and key dimensions; the query dimension depends on how
        # target_mapping is applied inside the model.
        self.parent.assertEqual(g_attentions.shape[:2], (self.batch_size, self.num_attention_heads))
        self.parent.assertEqual(g_attentions.shape[-1], self.seq_length)

    def create_and_check_xlnet_lm_head(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Check LM-head loss/logits shapes, and that mems passed back in are truncated to `mem_len`."""
        model = XLNetLMHeadModel(config)
        model.to(torch_device)
        model.eval()

        result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)

        result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1.mems)

        _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)

        self.parent.assertEqual(result1.loss.shape, ())
        self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertListEqual(
            [mem.shape for mem in result1.mems],
            [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
        )

        self.parent.assertEqual(result2.loss.shape, ())
        self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertListEqual(
            [mem.shape for mem in result2.mems],
            [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
        )

    def create_and_check_xlnet_qa(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Check `XLNetForQuestionAnswering` outputs with and without the optional label inputs."""
        model = XLNetForQuestionAnswering(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids_1)

        result_with_labels = model(
            input_ids_1,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
            p_mask=input_mask,
        )

        result_with_labels = model(
            input_ids_1,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
        )

        # The two-name unpacking doubles as a check that exactly two outputs are returned with labels.
        total_loss, mems = result_with_labels.to_tuple()

        result_with_labels = model(
            input_ids_1,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )

        total_loss, mems = result_with_labels.to_tuple()

        self.parent.assertEqual(result_with_labels.loss.shape, ())
        self.parent.assertEqual(result.start_top_log_probs.shape, (self.batch_size, model.config.start_n_top))
        self.parent.assertEqual(result.start_top_index.shape, (self.batch_size, model.config.start_n_top))
        self.parent.assertEqual(
            result.end_top_log_probs.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
        )
        self.parent.assertEqual(
            result.end_top_index.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
        )
        self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))
        self.parent.assertListEqual(
            [mem.shape for mem in result.mems],
            [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
        )

    def create_and_check_xlnet_token_classif(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Check token-classification loss, per-token logits and mems shapes."""
        model = XLNetForTokenClassification(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids_1)
        result = model(input_ids_1, labels=token_labels)

        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size))
        self.parent.assertListEqual(
            [mem.shape for mem in result.mems],
            [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
        )

    def create_and_check_xlnet_sequence_classif(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Check sequence-classification loss, per-sequence logits and mems shapes."""
        model = XLNetForSequenceClassification(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids_1)
        result = model(input_ids_1, labels=sequence_labels)

        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
        self.parent.assertListEqual(
            [mem.shape for mem in result.mems],
            [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
        )

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, {"input_ids": ...})` as expected by the shared model tester mixins."""
        config, input_ids_1, *_ = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids_1}
        return config, inputs_dict
@require_torch
class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Common-model and generation tests for all XLNet heads, plus XLNet-specific checks."""

    all_model_classes = (
        (
            XLNetModel,
            XLNetLMHeadModel,
            XLNetForTokenClassification,
            XLNetForSequenceClassification,
            XLNetForQuestionAnswering,
            XLNetForQuestionAnsweringSimple,
            XLNetForMultipleChoice,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (XLNetLMHeadModel,) if is_torch_available() else ()
    )  # TODO (PVP): Check other models whether language generation is also applicable
    test_pruning = False

    # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Extend the common preparation with zero start/end positions for `XLNetForQuestionAnswering`."""
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class.__name__ == "XLNetForQuestionAnswering":
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )

        return inputs_dict

    def setUp(self):
        self.model_tester = XLNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlnet_base_model(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)

    def test_xlnet_base_model_use_mems(self):
        # checking that in auto-regressive mode, `use_mems` gives the same results
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_model_use_mems(*config_and_inputs)

    def test_seq_classification_use_mems_train(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_use_mems_train(*config_and_inputs)

    def test_xlnet_base_model_with_att_output(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs)

    def test_xlnet_lm_head(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)

    def test_xlnet_sequence_classif(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)

    def test_xlnet_token_classif(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_token_classif(*config_and_inputs)

    def test_xlnet_qa(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)

    def test_retain_grad_hidden_states_attentions(self):
        # Report the common test as skipped rather than silently passing it.
        self.skipTest("XLNet cannot keep gradients in attentions or hidden states")

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        """Fill weights, biases and XLNet's extra parameter tensors with the constant 3."""
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)

        # XLNet keeps several parameters as plain attributes rather than `weight`/`bias`.
        for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]:
            if hasattr(module, param) and getattr(module, param) is not None:
                weight = getattr(module, param)
                weight.data.fill_(3)

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        """XLNet variant: odd-indexed tensors per step come from the query stream (length 1)."""
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            # check hidden size
            for i, layer_hidden_states in enumerate(iter_hidden_states):
                # every 2nd tensor is from extra stream
                if i % 2 != 0:
                    seq_len = 1
                else:
                    # for first item dummy PAD token is appended so need one more
                    seq_len = (min_length + 1) if idx == 0 else min_length

                expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
                self.assertEqual(layer_hidden_states.shape, expected_shape)

    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        """XLNet variant: the first step's query length includes the dummy PAD token."""
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, attentions_item in enumerate(attentions):
            for iter_attentions in attentions_item:
                tgt_len = min_length

                # for first item dummy PAD token is appended so need one more
                if idx == 0:
                    tgt_len += 1

                src_len = min_length + idx + 1

                expected_shape = (
                    batch_size * num_beam_groups,
                    config.num_attention_heads,
                    tgt_len,
                    src_len,
                )
                # check attn size
                self.assertListEqual(
                    [layer_attention.shape for layer_attention in iter_attentions],
                    [expected_shape] * len(iter_attentions),
                )

    @slow
    def test_model_from_pretrained(self):
        for model_name in XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = XLNetModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
@require_torch
class XLNetModelLanguageGenerationTest(unittest.TestCase):
    """Integration test: greedy generation with the pretrained `xlnet-base-cased` checkpoint."""

    @slow
    def test_lm_generate_xlnet_base_cased(self):
        model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
        model.to(torch_device)

        # Prompt (161 tokens), decoded:
        # In 1991, the remains of Russian Tsar Nicholas II and his family
        # (except for Alexei and Maria) are discovered.
        # The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
        # remainder of the story. 1883 Western Siberia,
        # a young Grigori Rasputin is asked by his father and a group of men to perform magic.
        # Rasputin has a vision and denounces one of the men as a horse thief. Although his
        # father initially slaps him for making such an accusation, Rasputin watches as the
        # man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
        # the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
        # with people, even a bishop, begging for his blessing.
        prompt_ids = [
            67, 2840, 19, 18, 1484, 20, 965, 29077, 8719, 1273,
            21, 45, 273, 17, 10, 15048, 28, 27511, 21, 4185,
            11, 41, 2444, 9, 32, 1025, 20, 8719, 26, 23,
            673, 966, 19, 29077, 20643, 27511, 20822, 20643, 19, 17,
            6616, 17511, 18, 8978, 20, 18, 777, 9, 19233, 1527,
            17669, 19, 24, 673, 17, 28756, 150, 12943, 4354, 153,
            27, 442, 37, 45, 668, 21, 24, 256, 20, 416,
            22, 2771, 4901, 9, 12943, 4354, 153, 51, 24, 3004,
            21, 28142, 23, 65, 20, 18, 416, 34, 24, 2958,
            22947, 9, 1177, 45, 668, 3097, 13768, 23, 103, 28,
            441, 148, 48, 20522, 19, 12943, 4354, 153, 12860, 34,
            18, 326, 27, 17492, 684, 21, 6709, 9, 8585, 123,
            266, 19, 12943, 4354, 153, 6872, 24, 3004, 20, 18,
            9225, 2198, 19, 12717, 103, 22, 401, 24, 6348, 9,
            12943, 4354, 153, 1068, 2768, 2286, 19, 33, 104, 19,
            176, 24, 9313, 19, 20086, 28, 45, 10292, 9, 4,
            3,
        ]  # fmt: skip
        input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=torch_device)

        # Greedy decoding must echo the prompt unchanged and then emit these 39 tokens,
        # filling the sequence up to max_length=200. Decoded, the expected output is the
        # prompt text followed by:
        # <sep><cls>, Rasputin is asked to perform magic. He is asked to perform a ritual of the Virgin Mary.
        # He is asked to perform a ritual of the Virgin Mary. He is asked to perform
        continuation_ids = [
            19, 12943, 4354, 153, 27, 442, 22, 2771, 4901, 9,
            69, 27, 442, 22, 2771, 24, 11335, 20, 18, 9225,
            2198, 9, 69, 27, 442, 22, 2771, 24, 11335, 20,
            18, 9225, 2198, 9, 69, 27, 442, 22, 2771,
        ]  # fmt: skip
        expected_output_ids = prompt_ids + continuation_ids

        generated = model.generate(input_ids, max_length=200, do_sample=False)
        self.assertListEqual(generated[0].tolist(), expected_output_ids)