use torch constraints to check if covariance is positive definite during mean resizing. (#35693)

* use torch constraints to check for psd

* small nit

* Small change

* Small change for the ci

* nit
This commit is contained in:
Mohamed Abu El-Nasr 2025-01-28 18:33:42 +02:00 committed by GitHub
parent 61cbb723fc
commit ec7afad609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -37,6 +37,7 @@ import torch
from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from torch import Tensor, nn
from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint
@ -2425,14 +2426,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
# Check if the covariance is positive definite.
eigenvalues = torch.linalg.eigvals(covariance)
is_covariance_psd = bool(
(covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all()
)
epsilon = 1e-9
is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all()
if is_covariance_psd:
# If the covariance is positive definite, a distribution can be created, and we can sample new weights from it.
distribution = torch.distributions.multivariate_normal.MultivariateNormal(
mean_embeddings, covariance_matrix=1e-9 * covariance
mean_embeddings, covariance_matrix=epsilon * covariance
)
new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
sample_shape=(added_num_tokens,)