From 3e35ea1782959d9e70d25048ca28faa516274236 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 6 Jun 2025 10:08:05 +0200 Subject: [PATCH] Improve `test_initialization` (#38607) * fix flaky init tests * fix flaky init tests --------- Co-authored-by: ydshieh --- tests/test_modeling_common.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9011f341007..bc62d56894e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -716,8 +716,16 @@ class ModelTesterMixin: model = model_class(config=configs_no_init) for name, param in model.named_parameters(): if param.requires_grad: + data = torch.flatten(param.data) + n_elements = torch.numel(data) + # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in + # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332 + n_elements_to_skip_on_each_side = int(n_elements * 0.025) + data_to_check = torch.sort(data).values + if n_elements_to_skip_on_each_side > 0: + data_to_check = data_to_check[n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side] self.assertIn( - ((param.data.mean() * 1e9).round() / 1e9).item(), + ((data_to_check.mean() * 1e9).round() / 1e9).item(), [0.0, 1.0], msg=f"Parameter {name} of model {model_class} seems not properly initialized", )