test plm model

JiwenJ 2025-04-20 05:35:07 +00:00
parent ef97fe7e0e
commit b525fba43f

@@ -415,121 +415,121 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         self.config_tester.run_common_tests()
     def test_model(self):
-        # config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        # self.model_tester.create_and_check_model(*config_and_inputs)
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
-    # def test_model_various_embeddings(self):
-        # config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        # for type in ["absolute", "relative_key", "relative_key_query"]:
-            # config_and_inputs[0].position_embedding_type = type
-            # self.model_tester.create_and_check_model(*config_and_inputs)
+    def test_model_various_embeddings(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        for type in ["absolute", "relative_key", "relative_key_query"]:
+            config_and_inputs[0].position_embedding_type = type
+            self.model_tester.create_and_check_model(*config_and_inputs)
-    # @parameterized.expand([("yarn",)])
-    # def test_model_rope_scaling_from_config(self, scaling_type):
-        # config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        # short_input = ids_tensor([1, 10], config.vocab_size)
-        # long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
+    @parameterized.expand([("yarn",)])
+    def test_model_rope_scaling_from_config(self, scaling_type):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        short_input = ids_tensor([1, 10], config.vocab_size)
+        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-        # set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        # original_model = PLMModel(config)
-        # original_model.to(torch_device)
-        # original_model.eval()
-        # original_short_output = original_model(short_input).last_hidden_state
-        # original_long_output = original_model(long_input).last_hidden_state
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        original_model = PLMModel(config)
+        original_model.to(torch_device)
+        original_model.eval()
+        original_short_output = original_model(short_input).last_hidden_state
+        original_long_output = original_model(long_input).last_hidden_state
-        # set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        # config.rope_scaling = {"type": scaling_type, "factor": 10.0}
-        # scaled_model = PLMModel(config)
-        # scaled_model.to(torch_device)
-        # scaled_model.eval()
-        # scaled_short_output = scaled_model(short_input).last_hidden_state
-        # scaled_long_output = scaled_model(long_input).last_hidden_state
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
+        scaled_model = PLMModel(config)
+        scaled_model.to(torch_device)
+        scaled_model.eval()
+        scaled_short_output = scaled_model(short_input).last_hidden_state
+        scaled_long_output = scaled_model(long_input).last_hidden_state
-        # # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
-        # # maximum sequence length, so the outputs for the short input should match.
-        # if scaling_type == "dynamic":
-            # torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
-        # else:
-            # self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
+        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
+        # maximum sequence length, so the outputs for the short input should match.
+        if scaling_type == "dynamic":
+            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
+        else:
+            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-        # # The output should be different for long inputs
-        # self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
+        # The output should be different for long inputs
+        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-    # def test_model_rope_scaling(self):
-        # config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        # scaling_factor = 10
-        # short_input_length = 10
-        # long_input_length = int(config.max_position_embeddings * 1.5)
+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
-        # # Inputs
-        # x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
-        # position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        # position_ids_short = position_ids_short.unsqueeze(0)
-        # position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        # position_ids_long = position_ids_long.unsqueeze(0)
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
+        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
+        position_ids_short = position_ids_short.unsqueeze(0)
+        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
+        position_ids_long = position_ids_long.unsqueeze(0)
-        # # Sanity check original RoPE
-        # original_rope = PLMRotaryEmbedding(config=config).to(torch_device)
-        # original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        # original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        # torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        # torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
+        # Sanity check original RoPE
+        original_rope = PLMRotaryEmbedding(config=config).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
+        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
-    # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non-standard format")
-    # def test_past_key_values_format(self):
-        # pass
+    @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non-standard format")
+    def test_past_key_values_format(self):
+        pass
-    # @require_torch_sdpa
-    # @slow
-    # def test_eager_matches_sdpa_generate(self):
-        # """
-        # Overwriting the common test as the test is flaky on tiny models
-        # """
-        # max_new_tokens = 30
+    @require_torch_sdpa
+    @slow
+    def test_eager_matches_sdpa_generate(self):
+        """
+        Overwriting the common test as the test is flaky on tiny models
+        """
+        max_new_tokens = 30
-        # tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base")
+        tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base")
-        # model_sdpa = PLMForCausalLM.from_pretrained(
-            # "PLM-Team/PLM-1.8B-Base",
-            # torch_dtype=torch.float16,
-            # low_cpu_mem_usage=True,
-            # trust_remote_code=True,
-        # ).to(torch_device)
+        model_sdpa = PLMForCausalLM.from_pretrained(
+            "PLM-Team/PLM-1.8B-Base",
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+        ).to(torch_device)
-        # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+        self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
-        # model_eager = PLMForCausalLM.from_pretrained(
-            # "PLM-Team/PLM-1.8B-Base",
-            # torch_dtype=torch.float16,
-            # low_cpu_mem_usage=True,
-            # attn_implementation="eager",
-            # trust_remote_code=True,
-        # ).to(torch_device)
-        # breakpoint()
-        # self.assertTrue(model_eager.config._attn_implementation == "eager")
+        model_eager = PLMForCausalLM.from_pretrained(
+            "PLM-Team/PLM-1.8B-Base",
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            attn_implementation="eager",
+            trust_remote_code=True,
+        ).to(torch_device)
+        self.assertTrue(model_eager.config._attn_implementation == "eager")
-        # texts = [
-            # "hi here's a longer context, getting longer and",
-            # "Hello this is a very long sentence my friend, very long for real",
-            # "Today I am in Paris and",
-        # ]
+        texts = [
+            "hi here's a longer context, getting longer and",
+            "Hello this is a very long sentence my friend, very long for real",
+            "Today I am in Paris and",
+        ]
-        # for padding_side in ["left", "right"]:
-            # tokenizer.padding_side = padding_side
-            # tokenizer.pad_token = tokenizer.eos_token
+        for padding_side in ["left", "right"]:
+            tokenizer.padding_side = padding_side
+            tokenizer.pad_token = tokenizer.eos_token
-            # inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
+            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
-            # res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
-            # res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+            res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+            res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
-            # with self.subTest(f"{padding_side}"):
-                # torch.testing.assert_close(
-                    # res_eager,
-                    # res_sdpa,
-                    # msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
-                # )
+            with self.subTest(f"{padding_side}"):
+                torch.testing.assert_close(
+                    res_eager,
+                    res_sdpa,
+                    msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
+                )
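
For context, the sketch below replays the core comparison that the re-enabled test_model_rope_scaling_from_config performs, outside the test harness: build two identically seeded PLM models, give the second a yarn rope_scaling config, and check that their hidden states diverge. It is a minimal illustration under assumptions, not part of this commit: the PLMConfig/PLMModel imports assume the PLM classes are exposed by transformers once this model lands, the tiny config values are hypothetical, and PLM's MLA layers may require config fields not shown here.

```python
import torch
from transformers import PLMConfig, PLMModel, set_seed

# Hypothetical tiny config for illustration; the real test pulls its
# config from self.model_tester instead of constructing one by hand.
config = PLMConfig(
    vocab_size=128,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=64,
)

short_input = torch.randint(0, config.vocab_size, (1, 10))
long_input = torch.randint(0, config.vocab_size, (1, int(config.max_position_embeddings * 1.5)))

set_seed(42)  # identical seeds so both models start from the same random weights
original = PLMModel(config).eval()

set_seed(42)
config.rope_scaling = {"type": "yarn", "factor": 10.0}  # the same dict the test sets
scaled = PLMModel(config).eval()

with torch.no_grad():
    original_short = original(short_input).last_hidden_state
    scaled_short = scaled(short_input).last_hidden_state
    original_long = original(long_input).last_hidden_state
    scaled_long = scaled(long_input).last_hidden_state

# yarn is a static scaling method, so the outputs should already differ on the
# short input; dynamic scaling would only diverge past max_position_embeddings.
print(torch.allclose(original_short, scaled_short, atol=1e-5))  # expected: False
print(torch.allclose(original_long, scaled_long, atol=1e-5))    # expected: False
```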