mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Fix more flaky test_initialization (#38932)
* try
* try
* fix
* fix
* fix
---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 5ee60f970a
commit b8059e1f8f
@@ -475,8 +475,19 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
                             msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                         )
                     else:
+                        # See PR #38607 (to avoid flakiness)
+                        data = torch.flatten(param.data)
+                        n_elements = torch.numel(data)
+                        # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
+                        # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
+                        n_elements_to_skip_on_each_side = int(n_elements * 0.025)
+                        data_to_check = torch.sort(data).values
+                        if n_elements_to_skip_on_each_side > 0:
+                            data_to_check = data_to_check[
+                                n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side
+                            ]
                         self.assertIn(
-                            ((param.data.mean() * 1e9).round() / 1e9).item(),
+                            ((data_to_check.mean() * 1e9).round() / 1e9).item(),
                             [0.0, 1.0],
                             msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                         )
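Each hunk in this commit applies the same trimmed-mean pattern. Distilled into a minimal standalone sketch (the `trimmed_mean_ok` helper name is hypothetical, not part of this commit), the check amounts to:

import torch

def trimmed_mean_ok(param: torch.Tensor, trim_frac: float = 0.025) -> bool:
    # Flatten and sort, then drop `trim_frac` of the elements on each side to
    # discard the tails produced by `nn.init.trunc_normal_` (see
    # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332).
    data = torch.sort(torch.flatten(param)).values
    n_skip = int(torch.numel(data) * trim_frac)
    if n_skip > 0:
        data = data[n_skip:-n_skip]
    # Round the mean to 9 decimal places, as the tests do, and accept the two
    # expected init means: 0.0 (zero-mean init) or 1.0 (ones init).
    return ((data.mean() * 1e9).round() / 1e9).item() in [0.0, 1.0]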
@@ -311,8 +311,19 @@ class DepthProModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
                 ]
                 if param.requires_grad:
                     if any(x in name for x in non_uniform_init_parms):
+                        # See PR #38607 (to avoid flakiness)
+                        data = torch.flatten(param.data)
+                        n_elements = torch.numel(data)
+                        # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
+                        # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
+                        n_elements_to_skip_on_each_side = int(n_elements * 0.025)
+                        data_to_check = torch.sort(data).values
+                        if n_elements_to_skip_on_each_side > 0:
+                            data_to_check = data_to_check[
+                                n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side
+                            ]
                         self.assertIn(
-                            ((param.data.mean() * 1e9).round() / 1e9).item(),
+                            ((data_to_check.mean() * 1e9).round() / 1e9).item(),
                             [0.0, 1.0],
                             msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                         )
@@ -252,8 +252,17 @@ class Dinov2WithRegistersModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
             model = model_class(config=configs_no_init)
             for name, param in model.named_parameters():
                 if param.requires_grad and "register_tokens" not in name:
+                    # See PR #38607 (to avoid flakiness)
+                    data = torch.flatten(param.data)
+                    n_elements = torch.numel(data)
+                    # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
+                    # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
+                    n_elements_to_skip_on_each_side = int(n_elements * 0.025)
+                    data_to_check = torch.sort(data).values
+                    if n_elements_to_skip_on_each_side > 0:
+                        data_to_check = data_to_check[n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side]
                     self.assertIn(
-                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        ((data_to_check.mean() * 1e9).round() / 1e9).item(),
                         [0.0, 1.0],
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
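Note the `if n_elements_to_skip_on_each_side > 0` guard in every hunk: for tensors with fewer than 40 elements, `int(n_elements * 0.025)` is 0, and slicing with `data[0:-0]` would produce an empty tensor, so small parameters are checked in full. A quick illustration (illustrative only, not part of the commit):

import torch

data = torch.ones(10)
n_skip = int(torch.numel(data) * 0.025)    # 0 for tensors with < 40 elements
trimmed = data[n_skip:-n_skip] if n_skip > 0 else data
print(trimmed.numel())                     # 10 -- guard keeps the full tensor
print(data[0:-0].numel())                  # 0  -- without the guard, empty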
@@ -544,8 +544,19 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
                             msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                         )
                     else:
+                        # See PR #38607 (to avoid flakiness)
+                        data = torch.flatten(param.data)
+                        n_elements = torch.numel(data)
+                        # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
+                        # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
+                        n_elements_to_skip_on_each_side = int(n_elements * 0.025)
+                        data_to_check = torch.sort(data).values
+                        if n_elements_to_skip_on_each_side > 0:
+                            data_to_check = data_to_check[
+                                n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side
+                            ]
                         self.assertIn(
-                            ((param.data.mean() * 1e9).round() / 1e9).item(),
+                            ((data_to_check.mean() * 1e9).round() / 1e9).item(),
                             [0.0, 1.0],
                             msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                         )
@@ -249,8 +249,17 @@ class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
                 if "logit_scale" in name:
                     continue
                 if param.requires_grad:
+                    # See PR #38607 (to avoid flakiness)
+                    data = torch.flatten(param.data)
+                    n_elements = torch.numel(data)
+                    # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
+                    # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
+                    n_elements_to_skip_on_each_side = int(n_elements * 0.025)
+                    data_to_check = torch.sort(data).values
+                    if n_elements_to_skip_on_each_side > 0:
+                        data_to_check = data_to_check[n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side]
                     self.assertIn(
-                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        ((data_to_check.mean() * 1e9).round() / 1e9).item(),
                         [0.0, 1.0],
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
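For reference, a hedged sketch of how the `trimmed_mean_ok` helper sketched above would slot into one of these test_initialization loops (the actual tests inline the logic rather than calling a helper):

model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
    if param.requires_grad:
        self.assertTrue(
            trimmed_mean_ok(param.data),
            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
        )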