From d9fa1bad728f7dd78131b316686e18b8a8904196 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 15 Jan 2020 20:22:21 -0500 Subject: [PATCH] Fix failing torchscript test for xlnet model.parameters() order is apparently not stable (only for xlnet, for some reason) --- tests/test_modeling_common.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 420ee6564e0..a5d69fbd6c1 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -236,11 +236,14 @@ class ModelTesterMixin: loaded_model.to(torch_device) loaded_model.eval() - model_params = model.parameters() - loaded_model_params = loaded_model.parameters() + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) models_equal = True - for p1, p2 in zip(model_params, loaded_model_params): + for layer_name, p1 in model_state_dict.items(): + p2 = loaded_model_state_dict[layer_name] if p1.data.ne(p2.data).sum() > 0: models_equal = False