Change deprecated PT functions (#37041)

Change deprecated functions
This commit is contained in:
cyyever 2025-03-28 22:26:22 +08:00 committed by GitHub
parent c90e6e9625
commit aa3778afc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 8 deletions

View File

@ -1346,8 +1346,8 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
if hasattr(model_kwargs, "attention_mask"): if hasattr(model_kwargs, "attention_mask"):
position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1 position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1
else: else:
position_ids = torch.range( position_ids = torch.arange(
0, conditioning_embeds.shape[1] - 1, dtype=torch.long, device=conditioning_embeds.device 0, conditioning_embeds.shape[1], dtype=torch.long, device=conditioning_embeds.device
) )
position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1) position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1)

View File

@ -100,7 +100,7 @@ class ConvNextV2GRN(nn.Module):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# Compute and normalize global spatial feature maps # Compute and normalize global spatial feature maps
global_features = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True) global_features = torch.linalg.norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)
norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6) norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states

View File

@ -429,7 +429,7 @@ class JukeboxBottleneckBlock(nn.Module):
entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8)) # entropy ie how diverse entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8)) # entropy ie how diverse
used_curr = (_codebook_elem >= self.threshold).sum() used_curr = (_codebook_elem >= self.threshold).sum()
usage = torch.sum(usage) usage = torch.sum(usage)
dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape)) dk = torch.linalg.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk} return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}
def preprocess(self, hidden_states): def preprocess(self, hidden_states):
@ -437,11 +437,13 @@ class JukeboxBottleneckBlock(nn.Module):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if hidden_states.shape[-1] == self.codebook_width: if hidden_states.shape[-1] == self.codebook_width:
prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape)) prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(
np.prod(hidden_states.shape)
)
elif hidden_states.shape[-1] == 2 * self.codebook_width: elif hidden_states.shape[-1] == 2 * self.codebook_width:
x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :] x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( prenorm = (torch.linalg.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)) torch.linalg.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
) )
# Normalise # Normalise
@ -517,7 +519,9 @@ class JukeboxBottleneckBlock(nn.Module):
update_metrics = {} update_metrics = {}
# Loss # Loss
commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape) commit_loss = torch.linalg.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(
hidden_states.shape
)
# Passthrough # Passthrough
dequantised_states = hidden_states + (dequantised_states - hidden_states).detach() dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()