Fix torchscript tests for AltCLIP (#21102)

fix torchscript tests for AltCLIP

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-01-13 10:03:19 +01:00 committed by GitHub
parent b3a0aad37d
commit b210c83a78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -490,6 +490,15 @@ class AltCLIPModelTest(ModelTesterMixin, unittest.TestCase):
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
models_equal = True