mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Merge pull request #2540 from huggingface/torch14_fix
[PyTorch 1.4] Fix failing torchscript test for xlnet
This commit is contained in:
commit
880854846b
@ -236,11 +236,14 @@ class ModelTesterMixin:
|
|||||||
loaded_model.to(torch_device)
|
loaded_model.to(torch_device)
|
||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
|
|
||||||
model_params = model.parameters()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_params = loaded_model.parameters()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for p1, p2 in zip(model_params, loaded_model_params):
|
for layer_name, p1 in model_state_dict.items():
|
||||||
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
if p1.data.ne(p2.data).sum() > 0:
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
models_equal = False
|
models_equal = False
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user