mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Improve test_initialization
(#38607)
* fix flaky init tests * fix flaky init tests --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
89542fb81c
commit
3e35ea1782
@ -716,8 +716,16 @@ class ModelTesterMixin:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
data = torch.flatten(param.data)
|
||||
n_elements = torch.numel(data)
|
||||
# skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
|
||||
# https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
|
||||
n_elements_to_skip_on_each_side = int(n_elements * 0.025)
|
||||
data_to_check = torch.sort(data).values
|
||||
if n_elements_to_skip_on_each_side > 0:
|
||||
data_to_check = data_to_check[n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side]
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
((data_to_check.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user