diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 2357c20e213..35b2aebf6c0 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -538,6 +538,17 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 28213de84df..fdae7768da1 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -509,6 +509,17 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 7d9c6b5ba58..9a6e3da06c7 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -496,8 +496,28 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] @@ -920,8 +940,28 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] @@ -1116,8 +1156,28 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index 57f532da863..894f1e12799 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -647,6 +647,17 @@ class ChineseCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/clap/test_modeling_clap.py b/tests/models/clap/test_modeling_clap.py index 56a7dfdcfce..53b4082c68d 100644 --- a/tests/models/clap/test_modeling_clap.py +++ b/tests/models/clap/test_modeling_clap.py @@ -575,8 +575,28 @@ class ClapModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 98a3f838382..4239ce5ed0d 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -547,8 +547,28 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index 35dca117e2b..e931bdc8d56 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -528,8 +528,28 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/encodec/test_modeling_encodec.py b/tests/models/encodec/test_modeling_encodec.py index b6291aafedf..d678b707a59 100644 --- a/tests/models/encodec/test_modeling_encodec.py +++ b/tests/models/encodec/test_modeling_encodec.py @@ -229,6 +229,17 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + model_buffers = list(model.buffers()) for non_persistent_buffer in non_persistent_buffers.values(): found_buffer = False diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index 2544b7ee93f..e69d2226c89 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -964,8 +964,28 @@ class FlavaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): # Non persistent buffers won't be in original state dict loaded_model_state_dict.pop("text_model.embeddings.token_type_ids", None) + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/graphormer/test_modeling_graphormer.py b/tests/models/graphormer/test_modeling_graphormer.py index e874ebf0f44..bc860e1a8aa 100644 --- a/tests/models/graphormer/test_modeling_graphormer.py +++ b/tests/models/graphormer/test_modeling_graphormer.py @@ -321,6 +321,17 @@ class GraphormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + model_buffers = list(model.buffers()) for non_persistent_buffer in non_persistent_buffers.values(): found_buffer = False diff --git a/tests/models/groupvit/test_modeling_groupvit.py b/tests/models/groupvit/test_modeling_groupvit.py index f87cec426b0..261841277ad 100644 --- a/tests/models/groupvit/test_modeling_groupvit.py +++ b/tests/models/groupvit/test_modeling_groupvit.py @@ -630,8 +630,28 @@ class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index af08725ef18..19fe688bf4c 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -500,6 +500,17 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + model_buffers = list(model.buffers()) for non_persistent_buffer in non_persistent_buffers.values(): found_buffer = False diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index acf078ffe80..4dbd1fb0a8c 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -491,8 +491,28 @@ class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] @@ -670,8 +690,28 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 8ec023676d6..c49db3dcff2 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -669,8 +669,28 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 16ad704fd51..d86fc43a826 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -712,8 +712,28 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 614b29b160d..c504c1005d2 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -814,8 +814,28 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index 2efece44cae..ef2d11ac626 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -612,8 +612,28 @@ class XCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name]