mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Improve test_initialization (#38607)
* fix flaky init tests
* fix flaky init tests

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 89542fb81c
commit 3e35ea1782
@@ -716,8 +716,16 @@ class ModelTesterMixin:
|
|||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init)
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
|
data = torch.flatten(param.data)
|
||||||
|
n_elements = torch.numel(data)
|
||||||
|
# skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
|
||||||
|
# https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
|
||||||
|
n_elements_to_skip_on_each_side = int(n_elements * 0.025)
|
||||||
|
data_to_check = torch.sort(data).values
|
||||||
|
if n_elements_to_skip_on_each_side > 0:
|
||||||
|
data_to_check = data_to_check[n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side]
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
((data_to_check.mean() * 1e9).round() / 1e9).item(),
|
||||||
[0.0, 1.0],
|
[0.0, 1.0],
|
||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user