From 4c2538b863d8949a98d6b8dc1dea9ed4cf96a5df Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 9 Jul 2024 15:22:43 +0100 Subject: [PATCH] Test loading generation config with safetensor weights (#31550) fix test --- tests/utils/test_modeling_utils.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 134b4758e63..c86c340017b 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1424,20 +1424,15 @@ class ModelUtilsTest(TestCasePlus): self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__) def test_generation_config_is_loaded_with_model(self): - # Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy - # `transformers_version` field set to `foo`. If loading the file fails, this test also fails. + # Note: `TinyLlama/TinyLlama-1.1B-Chat-v1.0` has a `generation_config.json` containing `max_length: 2048` # 1. Load without further parameters - model = AutoModelForCausalLM.from_pretrained( - "joaogante/tiny-random-gpt2-with-generation-config", use_safetensors=False - ) - self.assertEqual(model.generation_config.transformers_version, "foo") + model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + self.assertEqual(model.generation_config.max_length, 2048) # 2. Load with `device_map` - model = AutoModelForCausalLM.from_pretrained( - "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto", use_safetensors=False - ) - self.assertEqual(model.generation_config.transformers_version, "foo") + model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto") + self.assertEqual(model.generation_config.max_length, 2048) @require_safetensors def test_safetensors_torch_from_torch(self):