test plm model

JiwenJ 2025-04-20 05:35:07 +00:00
parent ef97fe7e0e
commit b525fba43f


@@ -415,121 +415,121 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         self.config_tester.run_common_tests()

     def test_model(self):
-        # config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        # self.model_tester.create_and_check_model(*config_and_inputs)
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)

-    # def test_model_various_embeddings(self):
-        # config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        # for type in ["absolute", "relative_key", "relative_key_query"]:
-            # config_and_inputs[0].position_embedding_type = type
-            # self.model_tester.create_and_check_model(*config_and_inputs)
+    def test_model_various_embeddings(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        for type in ["absolute", "relative_key", "relative_key_query"]:
+            config_and_inputs[0].position_embedding_type = type
+            self.model_tester.create_and_check_model(*config_and_inputs)

-    # @parameterized.expand([("yarn",)])
-    # def test_model_rope_scaling_from_config(self, scaling_type):
-        # config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        # short_input = ids_tensor([1, 10], config.vocab_size)
-        # long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-        # set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        # original_model = PLMModel(config)
-        # original_model.to(torch_device)
-        # original_model.eval()
-        # original_short_output = original_model(short_input).last_hidden_state
-        # original_long_output = original_model(long_input).last_hidden_state
-        # set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        # config.rope_scaling = {"type": scaling_type, "factor": 10.0}
-        # scaled_model = PLMModel(config)
-        # scaled_model.to(torch_device)
-        # scaled_model.eval()
-        # scaled_short_output = scaled_model(short_input).last_hidden_state
-        # scaled_long_output = scaled_model(long_input).last_hidden_state
-        # # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
-        # # maximum sequence length, so the outputs for the short input should match.
-        # if scaling_type == "dynamic":
-            # torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
-        # else:
-            # self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-        # # The output should be different for long inputs
-        # self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
+    @parameterized.expand([("yarn",)])
+    def test_model_rope_scaling_from_config(self, scaling_type):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        short_input = ids_tensor([1, 10], config.vocab_size)
+        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        original_model = PLMModel(config)
+        original_model.to(torch_device)
+        original_model.eval()
+        original_short_output = original_model(short_input).last_hidden_state
+        original_long_output = original_model(long_input).last_hidden_state
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
+        scaled_model = PLMModel(config)
+        scaled_model.to(torch_device)
+        scaled_model.eval()
+        scaled_short_output = scaled_model(short_input).last_hidden_state
+        scaled_long_output = scaled_model(long_input).last_hidden_state
+        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
+        # maximum sequence length, so the outputs for the short input should match.
+        if scaling_type == "dynamic":
+            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
+        else:
+            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
+        # The output should be different for long inputs
+        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

-    # def test_model_rope_scaling(self):
-        # config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        # scaling_factor = 10
-        # short_input_length = 10
-        # long_input_length = int(config.max_position_embeddings * 1.5)
-        # # Inputs
-        # x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exlusively to get the dtype and the device
-        # position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        # position_ids_short = position_ids_short.unsqueeze(0)
-        # position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        # position_ids_long = position_ids_long.unsqueeze(0)
-        # # Sanity check original RoPE
-        # original_rope = PLMRotaryEmbedding(config=config).to(torch_device)
-        # original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        # original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        # torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        # torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exlusively to get the dtype and the device
+        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
+        position_ids_short = position_ids_short.unsqueeze(0)
+        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
+        position_ids_long = position_ids_long.unsqueeze(0)
+        # Sanity check original RoPE
+        original_rope = PLMRotaryEmbedding(config=config).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
+        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])

-    # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format")
-    # def test_past_key_values_format(self):
-        # pass
+    @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format")
+    def test_past_key_values_format(self):
+        pass

-    # @require_torch_sdpa
-    # @slow
-    # def test_eager_matches_sdpa_generate(self):
-        # """
-        # Overwritting the common test as the test is flaky on tiny models
-        # """
-        # max_new_tokens = 30
-        # tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base")
-        # model_sdpa = PLMForCausalLM.from_pretrained(
-            # "PLM-Team/PLM-1.8B-Base",
-            # torch_dtype=torch.float16,
-            # low_cpu_mem_usage=True,
-            # trust_remote_code=True,
-        # ).to(torch_device)
-        # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
-        # model_eager = PLMForCausalLM.from_pretrained(
-            # "PLM-Team/PLM-1.8B-Base",
-            # torch_dtype=torch.float16,
-            # low_cpu_mem_usage=True,
-            # attn_implementation="eager",
-            # trust_remote_code=True,
-        # ).to(torch_device)
-        # breakpoint()
-        # self.assertTrue(model_eager.config._attn_implementation == "eager")
-        # texts = [
-            # "hi here's a longer context, getting longer and",
-            # "Hello this is a very long sentence my friend, very long for real",
-            # "Today I am in Paris and",
-        # ]
-        # for padding_side in ["left", "right"]:
-            # tokenizer.padding_side = padding_side
-            # tokenizer.pad_token = tokenizer.eos_token
-            # inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
-            # res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
-            # res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
-            # with self.subTest(f"{padding_side}"):
-                # torch.testing.assert_close(
-                    # res_eager,
-                    # res_sdpa,
-                    # msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
-                # )
+    @require_torch_sdpa
+    @slow
+    def test_eager_matches_sdpa_generate(self):
+        """
+        Overwritting the common test as the test is flaky on tiny models
+        """
+        max_new_tokens = 30
+        tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base")
+        model_sdpa = PLMForCausalLM.from_pretrained(
+            "PLM-Team/PLM-1.8B-Base",
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+        ).to(torch_device)
+        self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+        model_eager = PLMForCausalLM.from_pretrained(
+            "PLM-Team/PLM-1.8B-Base",
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            attn_implementation="eager",
+            trust_remote_code=True,
+        ).to(torch_device)
+        breakpoint()
+        self.assertTrue(model_eager.config._attn_implementation == "eager")
+        texts = [
+            "hi here's a longer context, getting longer and",
+            "Hello this is a very long sentence my friend, very long for real",
+            "Today I am in Paris and",
+        ]
+        for padding_side in ["left", "right"]:
+            tokenizer.padding_side = padding_side
+            tokenizer.pad_token = tokenizer.eos_token
+            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
+            res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+            res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+            with self.subTest(f"{padding_side}"):
+                torch.testing.assert_close(
+                    res_eager,
+                    res_sdpa,
+                    msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
+                )
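The helper names used in this hunk (ids_tensor, set_seed, torch_device, require_torch_sdpa, slow, parameterized, PLMModel, PLMForCausalLM, PLMRotaryEmbedding) come from imports earlier in the file and are not part of the diff. A minimal sketch of the imports these tests appear to rely on, following the usual transformers test conventions; the exact module paths are assumptions, not taken from this commit:

    import unittest
    import torch
    from parameterized import parameterized
    from transformers import AutoTokenizer, set_seed
    from transformers.testing_utils import require_torch_sdpa, slow, torch_device
    # Model classes; assumed to live in the PLM modeling module this model port adds
    from transformers.models.plm.modeling_plm import PLMForCausalLM, PLMModel, PLMRotaryEmbedding
    # Shared test utilities; assumed relative imports within the transformers tests package
    from ...test_modeling_common import ModelTesterMixin, ids_tensor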