mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00

* enable csm test cases on XPU, all passed Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com>
693 lines
42 KiB
Python
693 lines
42 KiB
Python
# coding=utf-8
|
|
# Copyright 2024, The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch ConversationalSpeechModel model."""
|
|
|
|
import collections
|
|
import copy
|
|
import re
|
|
import unittest
|
|
|
|
import pytest
|
|
from parameterized import parameterized
|
|
|
|
from transformers import (
|
|
AutoProcessor,
|
|
CsmConfig,
|
|
CsmForConditionalGeneration,
|
|
is_torch_available,
|
|
)
|
|
from transformers.testing_utils import (
|
|
cleanup,
|
|
require_torch_accelerator,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
from transformers.utils.import_utils import is_datasets_available
|
|
|
|
from ...generation.test_utils import GenerationTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import (
|
|
ModelTesterMixin,
|
|
_config_zero_init,
|
|
ids_tensor,
|
|
)
|
|
|
|
|
|
if is_datasets_available():
|
|
from datasets import load_dataset
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers.pytorch_utils import id_tensor_storage
|
|
|
|
|
|
class CsmModelTester:
|
|
def __init__(
|
|
self,
|
|
parent,
|
|
ignore_index=-100,
|
|
batch_size=3,
|
|
seq_length=7,
|
|
is_training=True,
|
|
depth_decoder_config={
|
|
"num_codebooks": 10,
|
|
"backbone_hidden_size": 64,
|
|
"vocab_size": 6,
|
|
"hidden_size": 64,
|
|
"intermediate_size": 128,
|
|
"num_hidden_layers": 2,
|
|
"num_attention_heads": 4,
|
|
"num_key_value_heads": 2,
|
|
"hidden_act": "silu",
|
|
"max_position_embeddings": 10,
|
|
},
|
|
codec_config={
|
|
"model_type": "mimi",
|
|
"audio_channels": 1,
|
|
"chunk_in_sec": None,
|
|
"hidden_size": 32,
|
|
"num_filters": 8,
|
|
"num_residual_layers": 1,
|
|
"upsampling_ratios": [8, 4],
|
|
"codebook_size": 64,
|
|
"vector_quantization_hidden_dimension": 64,
|
|
"upsample_groups": 32,
|
|
"num_hidden_layers": 2,
|
|
"num_attention_heads": 2,
|
|
"num_key_value_heads": 2,
|
|
"sliding_window": 4,
|
|
"codebook_dim": 64,
|
|
"use_cache": False,
|
|
},
|
|
config={
|
|
"num_codebooks": 10,
|
|
"vocab_size": 6,
|
|
"text_vocab_size": 99,
|
|
"hidden_size": 64,
|
|
"intermediate_size": 64,
|
|
"num_hidden_layers": 2,
|
|
"num_attention_heads": 4,
|
|
"num_key_value_heads": 2,
|
|
"hidden_act": "silu",
|
|
"max_position_embeddings": 10,
|
|
"bos_token_id": 1,
|
|
"pad_token_id": 2,
|
|
"eos_token_id": 3,
|
|
"codebook_pad_token_id": 2,
|
|
"codebook_eos_token_id": 3,
|
|
},
|
|
):
|
|
self.parent = parent
|
|
self.is_training = is_training
|
|
self.ignore_index = ignore_index
|
|
self.depth_decoder_config = depth_decoder_config
|
|
self.codec_config = codec_config
|
|
self.config = config
|
|
self.seq_length = seq_length
|
|
self.batch_size = batch_size
|
|
|
|
self.num_hidden_layers = config["num_hidden_layers"]
|
|
self.vocab_size = config["vocab_size"]
|
|
self.hidden_size = config["hidden_size"]
|
|
self.num_attention_heads = config["num_attention_heads"]
|
|
self.pad_token_id = config["pad_token_id"]
|
|
|
|
def get_config(self):
|
|
return CsmConfig(
|
|
depth_decoder_config=self.depth_decoder_config,
|
|
codec_config=self.codec_config,
|
|
**self.config,
|
|
)
|
|
|
|
def prepare_config_and_inputs(self):
|
|
config = self.get_config()
|
|
input_ids = ids_tensor([self.batch_size, self.seq_length, config.num_codebooks], config.vocab_size - 1) + 1
|
|
attention_mask = input_ids[..., -1].ne(1).to(torch_device)
|
|
return config, input_ids, attention_mask
|
|
|
|
def prepare_config_and_inputs_for_common(self):
|
|
config, input_ids, attention_mask = self.prepare_config_and_inputs()
|
|
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
|
return config, inputs_dict
|
|
|
|
|
|
class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|
all_model_classes = (CsmForConditionalGeneration,) if is_torch_available() else ()
|
|
test_pruning = False
|
|
test_headmasking = False
|
|
test_resize_embeddings = False
|
|
test_resize_embeddings_untied = False
|
|
test_torch_exportable = True
|
|
|
|
def setUp(self):
|
|
self.model_tester = CsmModelTester(self)
|
|
self.config_tester = ConfigTester(self, config_class=CsmConfig)
|
|
|
|
def test_config(self):
|
|
self.config_tester.run_common_tests()
|
|
|
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
|
"""
|
|
Overrides [ModelTesterMixin._prepare_for_class] to handle third input_ids dimension.
|
|
"""
|
|
inputs_dict = copy.deepcopy(inputs_dict)
|
|
|
|
if return_labels:
|
|
inputs_dict["labels"] = torch.zeros(
|
|
(
|
|
self.model_tester.batch_size,
|
|
self.model_tester.seq_length,
|
|
self.model_tester.config["num_codebooks"],
|
|
),
|
|
dtype=torch.long,
|
|
device=torch_device,
|
|
)
|
|
|
|
return inputs_dict
|
|
|
|
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
|
"""
|
|
Overrides [GenerationTesterMixin._get_logits_processor_kwargs] to restrict to top_k, top_p, and temperature sampling.
|
|
"""
|
|
logits_processor_kwargs = {}
|
|
if do_sample:
|
|
logits_processor_kwargs.update(
|
|
{
|
|
"top_k": 10,
|
|
"top_p": 0.7,
|
|
"temperature": 0.7,
|
|
}
|
|
)
|
|
|
|
return logits_processor_kwargs
|
|
|
|
def test_initialization(self):
|
|
"""
|
|
Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model.
|
|
See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
|
|
"""
|
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
|
|
configs_no_init = _config_zero_init(config)
|
|
for model_class in self.all_model_classes:
|
|
model = model_class(config=configs_no_init)
|
|
for name, param in model.named_parameters():
|
|
uniform_init_parms = ["conv", "input_proj", "output_proj"]
|
|
if param.requires_grad:
|
|
if any(x in name for x in uniform_init_parms):
|
|
self.assertTrue(
|
|
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
|
)
|
|
|
|
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
|
|
"""
|
|
Overrides [GenerationTesterMixin._check_similar_generate_outputs] to handle third input_ids dimension.
|
|
Here we only look a the first codebook (index 0 on last dimension of the generated sequences) since returned scores
|
|
are for this token.
|
|
"""
|
|
# scores doesn't include data regarding decoder input tokens
|
|
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
|
|
output_matches = output_1.sequences[..., 0] == output_2.sequences[..., 0]
|
|
has_matching_outputs = output_matches.all()
|
|
has_matching_scores = None
|
|
if not has_matching_outputs:
|
|
for batch_idx in range(output_1.sequences.shape[0]):
|
|
batch_matches = output_matches[batch_idx]
|
|
if batch_matches.all():
|
|
continue
|
|
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
|
|
first_mismatch_idx -= decoder_input_length
|
|
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
|
|
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
|
|
has_matching_scores = torch.allclose(
|
|
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
|
|
)
|
|
if not has_matching_scores:
|
|
break
|
|
self.assertTrue(has_matching_outputs or has_matching_scores)
|
|
|
|
@parameterized.expand([("random",), ("same",)])
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support assisted decoding.")
|
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support assisted decoding.")
|
|
def test_assisted_decoding_sample(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support Dola decoding.")
|
|
def test_dola_decoding_sample(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support beam search.")
|
|
def test_beam_sample_generate(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support beam search.")
|
|
def test_beam_search_generate(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support beam search.")
|
|
def test_beam_search_generate_dict_output(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support beam search.")
|
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support beam search.")
|
|
def test_beam_sample_generate_dict_output(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support group beam search.")
|
|
def test_group_beam_search_generate(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support group beam search.")
|
|
def test_group_beam_search_generate_dict_output(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support constrained beam search.")
|
|
def test_constrained_beam_search_generate(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support constrained beam search.")
|
|
def test_constrained_beam_search_generate_dict_output(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support contrastive search.")
|
|
def test_contrastive_generate(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support contrastive search.")
|
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support contrastive search.")
|
|
def test_contrastive_generate_low_memory(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
|
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
|
|
def test_prompt_lookup_decoding_stops_at_eos(self):
|
|
pass
|
|
|
|
@pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
|
|
def test_model_get_set_embeddings(self):
|
|
pass
|
|
|
|
@pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
|
|
def test_tie_model_weights(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support beam search.")
|
|
def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(reason="CSM does not support beam search.")
|
|
def test_model_parallel_beam_search(self):
|
|
pass
|
|
|
|
def test_tied_weights_keys(self):
|
|
"""
|
|
Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
|
|
"""
|
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
for model_class in self.all_model_classes:
|
|
model_tied = model_class(config)
|
|
|
|
ptrs = collections.defaultdict(list)
|
|
for name, tensor in model_tied.state_dict().items():
|
|
ptrs[id_tensor_storage(tensor)].append(name)
|
|
|
|
# These are all the pointers of shared tensors.
|
|
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
|
|
|
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
|
# Detect we get a hit for each key
|
|
for key in tied_weight_keys:
|
|
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
|
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
|
|
|
# Removed tied weights found from tied params -> there should only be one left after
|
|
for key in tied_weight_keys:
|
|
for i in range(len(tied_params)):
|
|
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
|
|
|
tied_params = [group for group in tied_params if len(group) > 1]
|
|
self.assertListEqual(
|
|
tied_params,
|
|
[],
|
|
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
|
|
)
|
|
|
|
def _get_custom_4d_mask_test_data(self):
|
|
"""
|
|
Overrides [ModelTesterMixin._get_custom_4d_mask_test_data] to handle third input_ids dimension.
|
|
"""
|
|
# Sequence in which all but the last token is the same
|
|
input_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 2, 5]], device=torch_device, dtype=torch.int64)
|
|
input_ids = input_ids.unsqueeze(-1).expand(-1, -1, self.model_tester.config["num_codebooks"])
|
|
position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
|
|
|
|
# Combining common prefix with the unique ending tokens:
|
|
input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
|
|
|
|
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
|
|
mask_shared_prefix = torch.tensor(
|
|
[
|
|
[
|
|
[
|
|
[1, 0, 0, 0, 0, 0],
|
|
[1, 1, 0, 0, 0, 0],
|
|
[1, 1, 1, 0, 0, 0],
|
|
[1, 1, 1, 1, 0, 0],
|
|
[1, 1, 1, 0, 1, 0],
|
|
[1, 1, 1, 0, 0, 1],
|
|
]
|
|
]
|
|
],
|
|
)
|
|
# inverting the attention mask
|
|
mask_dtype = torch.float32
|
|
min_dtype = torch.finfo(mask_dtype).min
|
|
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
|
|
|
|
# Creating a position_ids tensor. note the repeating figures in the end.
|
|
position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
|
|
|
|
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
|
|
|
|
|
class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|
def setUp(self):
|
|
# TODO: @eustlb, update with correct sesame's repo
|
|
self.model_checkpoint = "eustlb/csm-1b"
|
|
|
|
def tearDown(self):
|
|
cleanup(torch_device, gc_collect=True)
|
|
|
|
def _load_conversation(self):
|
|
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
|
ds = ds.filter(lambda x: x["conversation_id"] == 0)
|
|
ds = ds.sort("turn_id")
|
|
return ds[0]
|
|
|
|
@slow
|
|
@require_torch_accelerator
|
|
def test_1b_model_integration_generate(self):
|
|
"""
|
|
Tests the generated tokens match the ones from the original model implementation.
|
|
Such tokens are to be retreived using https://gist.github.com/eustlb/d25577a357ddcf8f4a8cd0d00baca551, which is a script that infers the original model.
|
|
"""
|
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
|
prompt = "<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"
|
|
|
|
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
|
audio = ds[0]["audio"]["array"]
|
|
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(torch_device)
|
|
|
|
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
|
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
|
|
|
# fmt: off
|
|
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
|
|
[1140, 1818, 86, 1072, 1029, 1010, 796, 577, 1523, 1599, 902, 1308, 817, 232, 1860, 56, 327, 1399, 1069, 1014, 1980, 53, 407, 1841, 1559, 928, 972, 1432, 832, 1007, 1325, 371],
|
|
[955, 1390, 1503, 861, 265, 1753, 91, 1690, 389, 1025, 1086, 495, 1192, 1334, 773, 1277, 957, 1388, 513, 1110, 539, 349, 1865, 1515, 806, 1514, 237, 1424, 1783, 1928, 523, 1925],
|
|
[1925, 190, 654, 1538, 19, 37, 1923, 100, 1909, 1156, 1847, 1901, 975, 982, 2002, 544, 1933, 311, 79, 850, 238, 1034, 428, 1231, 764, 313, 973, 269, 1669, 1058, 1641, 891],
|
|
[1721, 92, 1298, 989, 1868, 154, 386, 1115, 347, 384, 853, 1439, 970, 1369, 238, 1279, 268, 595, 2010, 1861, 723, 999, 578, 1612, 69, 121, 306, 1647, 1609, 1185, 1786, 1268],
|
|
[1356, 1419, 1199, 1575, 418, 53, 1140, 805, 355, 324, 633, 199, 343, 1176, 784, 41, 268, 366, 1478, 466, 1591, 305, 1298, 1335, 1866, 1563, 1503, 1558, 1468, 852, 1244, 312],
|
|
[1860, 1603, 546, 1805, 607, 160, 1528, 191, 1867, 1830, 861, 661, 1740, 1276, 218, 954, 1286, 1216, 1727, 1637, 983, 597, 1857, 65, 797, 947, 427, 476, 739, 978, 107, 1394],
|
|
[1165, 1775, 177, 823, 100, 370, 521, 200, 2007, 434, 1444, 1205, 819, 1278, 31, 912, 150, 1546, 2035, 1147, 559, 1995, 639, 35, 1812, 56, 1485, 2003, 1573, 1693, 1762, 1313],
|
|
[1932, 704, 907, 897, 56, 1587, 990, 1905, 2007, 256, 671, 868, 282, 1731, 460, 1055, 1309, 1880, 584, 1849, 1643, 1198, 310, 361, 789, 1657, 905, 1564, 1354, 110, 915, 1011],
|
|
[1437, 1958, 1483, 313, 79, 28, 859, 397, 1783, 1693, 633, 1424, 1128, 1831, 605, 1123, 1496, 739, 1177, 498, 781, 1756, 1288, 890, 224, 1875, 279, 800, 1999, 1740, 348, 1420],
|
|
[724, 870, 1344, 861, 429, 522, 1877, 1689, 771, 1468, 1952, 156, 856, 462, 18, 834, 33, 840, 1136, 2012, 1766, 1891, 2034, 1731, 624, 108, 1469, 653, 1344, 1682, 407, 515],
|
|
[355, 26, 36, 1700, 1032, 293, 1799, 978, 944, 296, 1333, 1377, 664, 1249, 421, 516, 1178, 531, 1587, 899, 1, 1449, 934, 942, 1604, 1208, 1889, 710, 825, 2012, 1563, 1299],
|
|
[629, 15, 551, 861, 310, 918, 149, 1689, 1464, 1950, 1900, 1502, 1503, 615, 477, 1090, 1556, 1393, 1143, 1112, 1934, 416, 1604, 1470, 1501, 1594, 903, 1400, 972, 199, 1075, 1643],
|
|
[1281, 106, 1162, 1313, 115, 429, 1792, 1379, 1535, 1311, 743, 484, 333, 498, 547, 699, 1075, 1861, 1038, 1352, 166, 622, 759, 1398, 241, 138, 1330, 481, 1254, 1365, 985, 423],
|
|
[9, 520, 323, 25, 1873, 716, 1414, 1413, 266, 1449, 1265, 290, 1341, 836, 674, 411, 913, 911, 637, 1038, 1097, 1158, 1009, 803, 737, 154, 1388, 938, 466, 725, 1216, 1549],
|
|
[1944, 15, 62, 332, 540, 689, 106, 1805, 1303, 1787, 1724, 1011, 1515, 1442, 1197, 496, 2026, 1820, 906, 372, 322, 1413, 1305, 1674, 443, 1733, 828, 905, 1116, 1850, 1870, 786],
|
|
[221, 220, 1093, 1790, 759, 1266, 1169, 1379, 572, 1859, 1155, 596, 1398, 412, 1788, 1963, 167, 89, 1011, 1489, 714, 73, 486, 780, 1136, 254, 983, 138, 386, 800, 1819, 1857],
|
|
[1178, 1939, 107, 1605, 582, 1256, 420, 637, 648, 1023, 1809, 978, 1703, 278, 1668, 2044, 1599, 1321, 1670, 1716, 1155, 56, 602, 877, 886, 220, 910, 797, 1028, 1226, 869, 811],
|
|
[1432, 1926, 1197, 1687, 540, 1815, 658, 1080, 1162, 192, 315, 1713, 422, 586, 65, 947, 493, 1536, 13, 505, 1269, 456, 1042, 645, 512, 1394, 1124, 590, 1058, 1896, 1055, 1537],
|
|
[905, 564, 1739, 1594, 1201, 1773, 738, 994, 239, 1686, 1528, 368, 1791, 1924, 607, 44, 1320, 552, 1862, 1578, 591, 1434, 330, 1576, 1946, 1233, 113, 445, 669, 2041, 1242, 1406],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
|
]])
|
|
# fmt: on
|
|
|
|
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
|
|
|
@slow
|
|
@require_torch_accelerator
|
|
def test_1b_model_integration_generate_no_audio(self):
|
|
"""
|
|
Tests the generated tokens match the ones from the original model implementation.
|
|
Such tokens are to be retreived using https://gist.github.com/eustlb/aed822f765e928b9612e01b0d8836d69, which is a script that infers the original model.
|
|
"""
|
|
|
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
|
|
|
conversation = [
|
|
{"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
|
|
]
|
|
|
|
inputs = processor.apply_chat_template(conversation, tokenize=True, return_dict=True).to(torch_device)
|
|
|
|
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
|
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
|
|
|
print(output_tokens)
|
|
# fmt: off
|
|
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
|
|
[1656, 629, 723, 1785, 206, 1873, 1059, 1190, 1833, 240, 618, 350, 156, 109, 2010, 452, 435, 1764, 77, 654, 1133, 908, 1095, 74, 804, 494, 1760, 1343, 1312, 1464, 1657, 324],
|
|
[366, 1532, 1945, 21, 145, 1428, 1417, 1987, 1793, 1444, 356, 1491, 849, 333, 788, 426, 1423, 1004, 414, 1823, 1169, 257, 1892, 696, 1572, 998, 1098, 523, 390, 1977, 546, 1692],
|
|
[1343, 1382, 1288, 1744, 1685, 1154, 1837, 1156, 1680, 1641, 1479, 1548, 632, 824, 694, 2010, 671, 1251, 1822, 343, 638, 1372, 696, 1272, 144, 125, 1332, 579, 936, 77, 159, 357],
|
|
[456, 1534, 349, 274, 1956, 1502, 1268, 1038, 1911, 523, 1360, 1159, 761, 293, 718, 1143, 63, 705, 168, 550, 413, 1372, 1771, 787, 631, 693, 784, 1789, 2039, 1131, 1601, 918],
|
|
[456, 829, 2026, 1108, 1649, 207, 1308, 1440, 1192, 1394, 426, 546, 590, 36, 1682, 1827, 1387, 1425, 1909, 1500, 1438, 1297, 5, 888, 948, 1745, 1304, 1364, 1692, 131, 300, 1908],
|
|
[2027, 1431, 1037, 1789, 1296, 1264, 1331, 1787, 1235, 1902, 1161, 1591, 590, 561, 1633, 1218, 510, 148, 1962, 118, 212, 608, 565, 1869, 583, 598, 532, 658, 1416, 9, 1172, 493],
|
|
[1215, 460, 1722, 317, 1423, 716, 1589, 1177, 1927, 1860, 1756, 1552, 1674, 643, 74, 1256, 587, 1742, 771, 2028, 469, 1070, 1683, 1614, 699, 494, 2020, 139, 1365, 1171, 171, 904],
|
|
[1615, 339, 323, 317, 469, 714, 104, 2015, 1407, 278, 468, 77, 2007, 650, 1630, 269, 168, 934, 1544, 58, 1487, 1373, 705, 874, 1252, 2031, 1995, 254, 1334, 1171, 1911, 1607],
|
|
[1259, 693, 666, 1700, 1115, 607, 982, 769, 1106, 1500, 101, 88, 1698, 1864, 1358, 1594, 192, 153, 1868, 1654, 604, 1948, 526, 778, 172, 1664, 1966, 99, 1334, 1030, 1349, 1209],
|
|
[1211, 579, 1369, 492, 1725, 203, 1125, 778, 701, 1982, 1420, 155, 736, 1145, 2018, 609, 658, 561, 1147, 923, 1794, 1753, 116, 1374, 612, 956, 1587, 392, 1062, 2047, 901, 1931],
|
|
[460, 1093, 1346, 1917, 1223, 470, 271, 390, 547, 112, 143, 1633, 1030, 643, 96, 1759, 920, 1959, 75, 1280, 1630, 999, 333, 853, 1110, 1291, 1911, 57, 171, 1658, 1704, 1508],
|
|
[908, 500, 393, 184, 1437, 482, 2008, 1834, 356, 1435, 1550, 1407, 1236, 109, 1167, 452, 1141, 934, 207, 957, 660, 670, 28, 1066, 1252, 1932, 669, 906, 1904, 1820, 2043, 881],
|
|
[1599, 1031, 1474, 336, 1540, 571, 437, 1440, 1616, 1365, 1412, 1246, 400, 405, 1776, 96, 296, 38, 1597, 466, 1630, 1256, 1940, 887, 1769, 294, 285, 842, 1756, 1619, 451, 1529],
|
|
[1615, 339, 1722, 525, 942, 105, 1365, 670, 785, 1316, 465, 1860, 438, 968, 547, 1938, 1816, 1429, 1065, 1942, 660, 1446, 1093, 1066, 931, 121, 688, 1033, 1178, 754, 1783, 94],
|
|
[912, 1354, 598, 254, 341, 1980, 1166, 585, 1302, 473, 554, 242, 174, 2030, 2011, 325, 978, 1690, 258, 396, 1831, 1768, 1291, 1699, 2001, 433, 1414, 2012, 1045, 511, 533, 1104],
|
|
[80, 1791, 1062, 1136, 391, 568, 1651, 101, 959, 2043, 1683, 760, 794, 181, 570, 540, 1599, 20, 1017, 973, 1654, 396, 586, 778, 2044, 1664, 1911, 929, 66, 897, 510, 643],
|
|
[1161, 1093, 161, 1296, 589, 54, 906, 981, 1927, 605, 516, 1731, 1461, 1204, 1902, 920, 1488, 177, 805, 1402, 610, 1446, 1154, 1067, 2025, 645, 762, 1715, 415, 1658, 1713, 1607],
|
|
[374, 1444, 1577, 792, 1450, 628, 604, 1729, 322, 514, 1725, 540, 1070, 575, 653, 800, 250, 187, 569, 349, 354, 1573, 176, 793, 897, 359, 536, 276, 1224, 23, 145, 1287],
|
|
[1184, 415, 1644, 1737, 1788, 385, 784, 1861, 1172, 1118, 367, 1156, 234, 1946, 1742, 981, 828, 1798, 1821, 361, 1148, 670, 518, 1288, 761, 1050, 1642, 1006, 1747, 840, 1599, 720],
|
|
[1141, 1731, 1670, 1542, 1347, 1907, 683, 753, 1347, 68, 2031, 153, 556, 719, 736, 1759, 1131, 1073, 1747, 1730, 1487, 1137, 1869, 1624, 699, 1900, 748, 49, 1312, 735, 726, 1268],
|
|
[1141, 1383, 405, 1033, 490, 488, 1102, 471, 713, 1630, 447, 703, 1495, 1001, 1855, 354, 456, 411, 786, 853, 168, 407, 116, 699, 605, 128, 532, 1076, 208, 447, 1448, 1071],
|
|
[345, 1013, 948, 1728, 1837, 337, 930, 1226, 1643, 1729, 983, 1688, 2009, 435, 1358, 721, 42, 1779, 1332, 1077, 1873, 128, 1327, 125, 1226, 1704, 705, 1459, 1449, 862, 155, 1870],
|
|
[336, 904, 684, 184, 1542, 714, 1752, 1180, 1373, 1816, 504, 1716, 1066, 1086, 1212, 530, 1413, 1278, 75, 1347, 82, 1623, 1307, 1717, 1861, 494, 888, 1589, 670, 1999, 905, 1430],
|
|
[578, 554, 14, 523, 1016, 300, 1589, 1017, 356, 1583, 1654, 414, 449, 376, 1413, 58, 706, 963, 388, 1626, 131, 352, 1024, 1054, 2025, 1561, 77, 1589, 1486, 431, 1249, 1508],
|
|
[184, 2043, 169, 1673, 580, 162, 1752, 397, 1119, 2009, 697, 150, 1475, 157, 1523, 1402, 575, 86, 1373, 1230, 1564, 1308, 626, 1093, 1603, 1446, 1390, 1543, 1778, 1142, 1357, 1831],
|
|
[1484, 1987, 932, 1728, 1504, 1618, 291, 1865, 1151, 460, 1792, 141, 234, 2043, 829, 513, 435, 791, 1037, 1541, 65, 424, 1589, 1711, 312, 1306, 212, 686, 673, 984, 1914, 1549],
|
|
[513, 1536, 1844, 1319, 572, 1069, 121, 735, 1949, 1211, 1362, 1027, 105, 1379, 315, 1782, 706, 1658, 1510, 1989, 1443, 1690, 822, 1614, 1194, 1460, 992, 2040, 1178, 1474, 1110, 1326],
|
|
[1858, 194, 1594, 1935, 1622, 1892, 1577, 137, 1907, 2015, 757, 414, 1823, 836, 496, 530, 1385, 1503, 1065, 1554, 664, 525, 1031, 433, 69, 466, 1016, 1846, 1609, 1658, 911, 94],
|
|
[1134, 1744, 323, 691, 1837, 347, 1871, 172, 811, 91, 1883, 436, 1912, 23, 1336, 1684, 519, 1612, 1219, 1402, 728, 1953, 1658, 641, 27, 1340, 436, 139, 2008, 1030, 159, 324],
|
|
[1270, 1536, 1639, 414, 1387, 1170, 1067, 1701, 1414, 505, 1122, 36, 1731, 350, 1552, 1214, 1444, 30, 107, 172, 480, 1858, 655, 168, 1107, 691, 1272, 797, 1656, 548, 1407, 1375],
|
|
[1270, 286, 1371, 1552, 1622, 1739, 1348, 2018, 345, 1537, 1941, 2024, 1423, 740, 284, 513, 91, 1228, 2015, 385, 992, 39, 813, 803, 2025, 497, 663, 462, 1609, 334, 927, 1470],
|
|
[1718, 994, 265, 1421, 1622, 1098, 845, 1868, 832, 459, 447, 619, 1970, 929, 513, 63, 1448, 1509, 1219, 1942, 285, 1373, 1259, 1004, 11, 1040, 1984, 57, 188, 1687, 1475, 805],
|
|
[1157, 832, 480, 1225, 1019, 347, 326, 999, 125, 1542, 118, 1383, 1343, 1077, 1821, 1602, 1978, 1642, 618, 808, 692, 1953, 1353, 963, 619, 1291, 1016, 1458, 1995, 1688, 1872, 1718],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
|
]])
|
|
# fmt: on
|
|
|
|
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
|
|
|
@slow
|
|
@require_torch_accelerator
|
|
def test_1b_model_integration_generate_multiple_audio(self):
|
|
"""
|
|
Test the generated tokens match the ones from the original model implementation.
|
|
Such tokens are to be retreived using https://gist.github.com/eustlb/0c94de002e1325abb61d32217f74c0f8, which is a script that infers the original model.
|
|
"""
|
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
|
|
|
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
|
conversation = []
|
|
|
|
# context
|
|
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
|
|
conversation.append(
|
|
{
|
|
"role": f"{speaker_id}",
|
|
"content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
|
|
}
|
|
)
|
|
|
|
# text prompt
|
|
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
|
|
|
|
inputs = processor.apply_chat_template(
|
|
conversation,
|
|
tokenize=True,
|
|
return_dict=True,
|
|
).to(torch_device)
|
|
|
|
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
|
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
|
|
|
# fmt: off
|
|
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
|
|
[420, 1189, 1311, 318, 359, 694, 1550, 1044, 1614, 1437, 1978, 537, 554, 1681, 147, 1225, 422, 1357, 1681, 1619, 165, 641, 1132, 1975, 1568, 406, 756, 503, 1673, 1428, 762, 781],
|
|
[1848, 1412, 957, 1656, 871, 540, 1999, 175, 711, 1383, 1814, 104, 742, 1285, 733, 1251, 1165, 1915, 1392, 645, 1804, 913, 1772, 632, 376, 1507, 1132, 725, 716, 1121, 1769, 1509],
|
|
[429, 1138, 895, 1018, 1099, 257, 1395, 1015, 576, 1599, 497, 19, 1858, 1437, 282, 357, 1143, 828, 1481, 70, 985, 551, 935, 278, 1102, 1453, 1902, 755, 526, 498, 1441, 1733],
|
|
[546, 343, 1547, 879, 2039, 692, 1999, 1150, 1969, 1866, 1178, 199, 1913, 1738, 1530, 1728, 1193, 74, 695, 612, 1095, 1597, 1381, 683, 1385, 2045, 1069, 865, 438, 70, 1437, 318],
|
|
[1741, 1621, 733, 1580, 1006, 482, 1508, 1722, 1529, 1822, 745, 552, 142, 1568, 704, 480, 214, 552, 321, 1858, 1902, 1042, 1249, 1328, 1730, 1218, 1755, 597, 670, 738, 1056, 762],
|
|
[1264, 1561, 1307, 730, 1403, 688, 212, 949, 1871, 994, 1174, 674, 858, 293, 1577, 1221, 1024, 1535, 1224, 872, 509, 1971, 46, 440, 1531, 1100, 1466, 732, 964, 381, 1933, 1612],
|
|
[1407, 982, 1665, 1247, 1636, 1546, 939, 882, 1999, 618, 484, 1632, 66, 430, 290, 327, 351, 1236, 687, 504, 1973, 1073, 1233, 1972, 82, 1655, 361, 1612, 861, 1085, 880, 1407],
|
|
[584, 637, 304, 1805, 1683, 1381, 404, 862, 1278, 916, 1695, 370, 316, 1049, 237, 1187, 1389, 300, 680, 135, 1068, 1368, 810, 1392, 103, 1459, 1051, 644, 38, 1517, 790, 646],
|
|
[471, 1984, 1333, 553, 193, 319, 1604, 1546, 153, 513, 990, 839, 1714, 1998, 984, 1882, 1055, 476, 1821, 1476, 1522, 1817, 949, 1923, 1416, 1885, 1832, 1368, 1782, 1229, 436, 918],
|
|
[28, 1238, 489, 1580, 596, 1232, 840, 835, 297, 762, 474, 1106, 1761, 483, 1165, 923, 1184, 1181, 1724, 398, 1484, 860, 1945, 665, 1925, 14, 67, 1693, 1853, 1283, 1822, 1973],
|
|
[20, 637, 253, 1254, 738, 188, 593, 1239, 1768, 1047, 1703, 1512, 1398, 464, 13, 161, 651, 1844, 666, 210, 1510, 1798, 614, 1649, 1751, 341, 808, 915, 1965, 840, 778, 950],
|
|
[1879, 2028, 1405, 694, 432, 2036, 612, 387, 1843, 1204, 1044, 8, 1538, 542, 1198, 598, 1131, 760, 1217, 901, 800, 1046, 136, 639, 1320, 618, 606, 707, 574, 1288, 1254, 198],
|
|
[1874, 937, 1063, 1341, 254, 13, 359, 888, 1837, 1246, 980, 818, 2046, 1258, 1290, 1470, 2028, 1701, 228, 1766, 51, 93, 296, 991, 1094, 1694, 156, 1207, 401, 967, 867, 211],
|
|
[1762, 426, 1749, 2004, 314, 903, 1254, 220, 1330, 1813, 534, 102, 658, 1460, 603, 1046, 402, 2005, 783, 973, 1764, 210, 1458, 803, 605, 369, 669, 352, 1964, 1549, 632, 1375],
|
|
[1577, 386, 503, 1492, 604, 405, 1329, 349, 180, 875, 329, 196, 514, 1854, 925, 159, 1428, 1300, 1510, 329, 76, 1682, 1036, 854, 695, 1097, 816, 382, 1417, 697, 1693, 194],
|
|
[1109, 848, 1385, 126, 1136, 979, 687, 130, 2045, 140, 562, 361, 921, 1706, 1060, 1723, 165, 1304, 203, 1067, 158, 692, 980, 313, 1896, 1812, 839, 837, 985, 116, 866, 1049],
|
|
[1810, 1092, 1534, 1730, 773, 2044, 1098, 1326, 85, 249, 455, 1728, 860, 443, 1841, 1885, 1698, 864, 1747, 1083, 1591, 1785, 1577, 1001, 1025, 1837, 1504, 1839, 1900, 1932, 230, 968],
|
|
[1547, 1465, 896, 794, 613, 1383, 1806, 1984, 526, 671, 100, 519, 2037, 1631, 1724, 633, 824, 994, 893, 1448, 1793, 1237, 1855, 699, 349, 143, 270, 535, 1550, 101, 22, 1311],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
|
]])
|
|
# fmt: on
|
|
|
|
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
|
|
|
@slow
|
|
@require_torch_accelerator
|
|
def test_1b_model_integration_generate_batched(self):
|
|
"""
|
|
Test the generated tokens match the ones from the original model implementation.
|
|
Such tokens are to be retreived using https://gist.github.com/eustlb/bcc532b53161bc31da3d66cb07ae193f, which is a script that infers the original model.
|
|
"""
|
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
|
|
|
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
|
conversation = [
|
|
[
|
|
{
|
|
"role": f"{ds[0]['speaker_id']}",
|
|
"content": [
|
|
{"type": "text", "text": ds[0]["text"]},
|
|
{"type": "audio", "path": ds[0]["audio"]["array"]},
|
|
],
|
|
},
|
|
{
|
|
"role": f"{ds[1]['speaker_id']}",
|
|
"content": [
|
|
{"type": "text", "text": ds[1]["text"]},
|
|
],
|
|
},
|
|
],
|
|
[
|
|
{
|
|
"role": f"{ds[0]['speaker_id']}",
|
|
"content": [
|
|
{"type": "text", "text": ds[0]["text"]},
|
|
],
|
|
}
|
|
],
|
|
]
|
|
|
|
inputs = processor.apply_chat_template(
|
|
conversation,
|
|
tokenize=True,
|
|
return_dict=True,
|
|
).to(torch_device)
|
|
|
|
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
|
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
|
|
|
# fmt: off
|
|
EXPECTED_OUTPUT_TOKENS = torch.tensor([
|
|
[
|
|
[1140, 1818, 1713, 1072, 1029, 1185, 697, 358, 220, 481, 1127, 1779, 817, 891, 958, 1058, 672, 495, 426, 1135, 236, 1440, 829, 2023, 1097, 94, 926, 1830, 114, 307, 235, 1190],
|
|
[955, 968, 696, 676, 52, 618, 0, 1818, 1285, 143, 1733, 1268, 1317, 1510, 1027, 2033, 1276, 1744, 790, 638, 1179, 1125, 650, 266, 1180, 364, 1015, 1604, 1152, 154, 178, 284],
|
|
[1925, 274, 433, 273, 1391, 1528, 1683, 1120, 976, 944, 357, 1681, 847, 1783, 546, 857, 1662, 1695, 40, 152, 2039, 1076, 994, 1743, 265, 1751, 602, 981, 483, 981, 538, 1381],
|
|
[1908, 1625, 1975, 729, 1067, 1844, 837, 1849, 224, 1223, 1037, 1188, 1428, 1977, 317, 530, 990, 1670, 766, 1411, 811, 154, 433, 1645, 1565, 1291, 1390, 49, 1160, 1464, 1911, 1961],
|
|
[1908, 566, 175, 1387, 1437, 1873, 1785, 1536, 961, 414, 406, 1753, 835, 284, 764, 1522, 1889, 1816, 840, 440, 756, 860, 1753, 516, 601, 1498, 280, 1425, 1904, 1540, 1074, 314],
|
|
[1860, 296, 1766, 361, 1155, 1675, 528, 1975, 1286, 113, 1656, 237, 372, 580, 1571, 1958, 502, 893, 1300, 261, 313, 455, 693, 1658, 654, 1585, 1723, 721, 178, 679, 908, 1077],
|
|
[1165, 1787, 1877, 1904, 85, 609, 1007, 1724, 1959, 245, 645, 463, 1321, 1695, 192, 711, 1892, 1193, 302, 1835, 69, 940, 148, 913, 110, 108, 1244, 1510, 165, 726, 745, 1746],
|
|
[1405, 1410, 186, 1569, 1214, 1920, 1946, 1907, 990, 1152, 1401, 1713, 541, 115, 423, 616, 1191, 1149, 1122, 9, 303, 195, 906, 566, 1718, 668, 1637, 1975, 51, 2005, 1260, 1672],
|
|
[1932, 780, 143, 110, 286, 1460, 1136, 1366, 1788, 446, 645, 587, 1708, 189, 1295, 526, 1667, 735, 707, 1215, 27, 834, 1865, 182, 1776, 1130, 528, 1523, 1156, 316, 492, 1666],
|
|
[1437, 364, 314, 432, 575, 1640, 529, 1128, 973, 789, 1820, 808, 1317, 1681, 347, 471, 737, 1626, 1386, 75, 433, 517, 365, 1982, 1434, 1378, 1059, 56, 1475, 653, 1507, 861],
|
|
[724, 538, 1140, 1853, 76, 402, 0, 397, 330, 1787, 1382, 682, 1134, 296, 377, 997, 705, 627, 1700, 17, 1791, 1000, 1271, 1019, 1552, 1521, 668, 534, 433, 344, 1007, 1046],
|
|
[925, 1297, 1017, 1785, 1403, 520, 1603, 1908, 665, 1827, 951, 1588, 1526, 414, 1945, 1153, 1933, 1571, 1821, 104, 179, 769, 619, 117, 56, 790, 721, 992, 1284, 1495, 1459, 823],
|
|
[629, 1208, 689, 924, 1617, 1100, 1028, 1231, 1708, 1582, 200, 2011, 1611, 1966, 1153, 1326, 2036, 1515, 884, 1790, 581, 549, 1491, 701, 973, 836, 2031, 1249, 1411, 365, 1946, 1552],
|
|
[1281, 1305, 610, 1666, 676, 544, 1788, 315, 159, 809, 1333, 1785, 1159, 1084, 1356, 318, 1933, 854, 475, 638, 1616, 1801, 1816, 1921, 283, 1745, 814, 974, 1056, 1316, 1509, 2031],
|
|
[9, 212, 1590, 163, 1289, 923, 2046, 1620, 632, 127, 963, 405, 850, 471, 1430, 108, 1845, 1196, 1928, 143, 1717, 1054, 1288, 1351, 1340, 1294, 831, 480, 1562, 2004, 483, 1776],
|
|
[221, 142, 1555, 1434, 1481, 1371, 1873, 1607, 207, 631, 1042, 1084, 472, 465, 1772, 1002, 1761, 1912, 1298, 1918, 685, 1053, 1635, 1536, 497, 55, 1432, 1394, 1512, 365, 2026, 1210],
|
|
[1741, 1923, 930, 1423, 1258, 1227, 879, 1217, 1999, 422, 420, 1832, 1660, 1542, 92, 2000, 1790, 1909, 56, 695, 704, 1752, 371, 792, 625, 328, 567, 1397, 1557, 390, 1424, 14],
|
|
[1178, 812, 577, 895, 1386, 339, 1467, 844, 235, 703, 551, 2021, 1592, 1042, 353, 621, 1672, 653, 2029, 103, 766, 182, 2016, 1921, 556, 1092, 1579, 626, 1950, 70, 1467, 850],
|
|
[1352, 472, 577, 351, 1126, 1943, 52, 2028, 430, 1017, 1136, 645, 820, 2028, 723, 1385, 1922, 323, 106, 267, 438, 1064, 202, 1249, 244, 1962, 625, 1380, 476, 924, 1221, 1854],
|
|
[905, 811, 374, 2021, 1067, 675, 927, 427, 416, 1521, 663, 77, 457, 1849, 1362, 262, 1669, 1238, 286, 102, 555, 1809, 1585, 1918, 972, 1446, 688, 523, 1904, 943, 17, 904],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
|
],
|
|
[
|
|
[1375, 203, 265, 164, 200, 1867, 976, 924, 1972, 1637, 1048, 271, 1912, 1430, 853, 1942, 260, 1642, 400, 57, 1376, 1626, 1821, 1163, 619, 777, 1076, 951, 389, 1820, 84, 1417],
|
|
[914, 527, 286, 968, 305, 1314, 805, 1703, 87, 559, 1980, 1124, 1726, 36, 1139, 618, 1628, 519, 1943, 781, 400, 1265, 438, 113, 87, 856, 465, 162, 1099, 352, 1141, 274],
|
|
[1408, 6, 126, 2009, 90, 996, 934, 134, 1857, 126, 602, 876, 1092, 1962, 1205, 828, 707, 1063, 393, 1533, 123, 1086, 1749, 1324, 1, 1763, 1707, 1191, 34, 1323, 1017, 1787],
|
|
[1000, 683, 1630, 703, 1574, 587, 25, 1049, 213, 1270, 1641, 1072, 1892, 1634, 1603, 90, 867, 2037, 1021, 715, 206, 507, 1138, 959, 1822, 1785, 280, 1100, 1660, 251, 1903, 988],
|
|
[1657, 1981, 246, 1048, 1952, 451, 305, 423, 2000, 416, 756, 1748, 7, 748, 1866, 1795, 1682, 1832, 338, 212, 1685, 518, 154, 1407, 416, 765, 776, 25, 55, 458, 612, 262],
|
|
[1034, 564, 667, 1474, 1212, 350, 712, 941, 1151, 1182, 1280, 640, 924, 1722, 1816, 458, 226, 359, 1518, 102, 1203, 459, 676, 1788, 1110, 393, 1974, 1721, 795, 1459, 798, 1723],
|
|
[742, 1616, 119, 653, 441, 679, 246, 1432, 486, 1615, 1191, 500, 650, 223, 687, 1765, 1875, 963, 1385, 863, 151, 1771, 458, 1170, 737, 1932, 785, 1954, 1067, 16, 1986, 2029],
|
|
[1437, 1078, 1767, 1452, 1392, 45, 2010, 1664, 245, 2015, 1416, 1055, 457, 985, 740, 1594, 1562, 1838, 258, 1431, 701, 604, 1813, 352, 792, 632, 21, 895, 70, 609, 850, 1599],
|
|
[983, 1961, 54, 135, 846, 711, 473, 1630, 1373, 1094, 251, 525, 632, 1014, 1594, 1594, 1752, 398, 1266, 1357, 942, 1680, 191, 874, 483, 1291, 381, 1873, 1964, 1278, 1477, 122],
|
|
[1663, 1969, 1887, 113, 145, 251, 1133, 156, 245, 1641, 209, 1322, 2037, 836, 539, 667, 940, 797, 1758, 1357, 191, 1137, 587, 1699, 27, 701, 395, 99, 1682, 876, 762, 839],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
]
|
|
])
|
|
# fmt: on
|
|
|
|
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|