mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
use torch constraints to check if covariance is positive definite during mean resizing. (#35693)
* use torch constraints to check for psd * small nit * Small change * Small change for the ci * nit
This commit is contained in:
parent
61cbb723fc
commit
ec7afad609
@ -37,6 +37,7 @@ import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from packaging import version
|
||||
from torch import Tensor, nn
|
||||
from torch.distributions import constraints
|
||||
from torch.nn import CrossEntropyLoss, Identity
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
@ -2425,14 +2426,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
|
||||
|
||||
# Check if the covariance is positive definite.
|
||||
eigenvalues = torch.linalg.eigvals(covariance)
|
||||
is_covariance_psd = bool(
|
||||
(covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all()
|
||||
)
|
||||
epsilon = 1e-9
|
||||
is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all()
|
||||
if is_covariance_psd:
|
||||
# If covariances is positive definite, a distribution can be created. and we can sample new weights from it.
|
||||
distribution = torch.distributions.multivariate_normal.MultivariateNormal(
|
||||
mean_embeddings, covariance_matrix=1e-9 * covariance
|
||||
mean_embeddings, covariance_matrix=epsilon * covariance
|
||||
)
|
||||
new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
|
||||
sample_shape=(added_num_tokens,)
|
||||
|
Loading…
Reference in New Issue
Block a user