Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
parent c90e6e9625
commit aa3778afc2

The commit replaces deprecated torch.range and torch.norm calls with torch.arange and torch.linalg.norm across the CLVP, ConvNeXt V2, and Jukebox modeling code, reflowing the longer replacement lines where needed.
@@ -1346,8 +1346,8 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
         if hasattr(model_kwargs, "attention_mask"):
             position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1
         else:
-            position_ids = torch.range(
-                0, conditioning_embeds.shape[1] - 1, dtype=torch.long, device=conditioning_embeds.device
+            position_ids = torch.arange(
+                0, conditioning_embeds.shape[1], dtype=torch.long, device=conditioning_embeds.device
             )
         position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1)
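A quick sanity check on the endpoint change, as a standalone sketch (the tensor shapes here are made up for illustration): torch.range included its upper bound and has long been deprecated, so torch.range(0, n - 1) and torch.arange(0, n) describe the same n positions.

import torch

# Hypothetical stand-in for conditioning_embeds: (batch, seq_len, dim).
conditioning_embeds = torch.randn(3, 5, 8)

# torch.range(0, seq_len - 1) was inclusive of its end point; torch.arange(0, seq_len)
# is exclusive, so dropping the "- 1" keeps the same values [0, 1, ..., seq_len - 1].
position_ids = torch.arange(0, conditioning_embeds.shape[1], dtype=torch.long)
position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1)
print(position_ids.shape)  # torch.Size([3, 5]), one row of positions per batch item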
@@ -100,7 +100,7 @@ class ConvNextV2GRN(nn.Module):

     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         # Compute and normalize global spatial feature maps
-        global_features = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True)
+        global_features = torch.linalg.norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)
         norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
         hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
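One subtlety worth flagging in review: with a numeric ord and a 2-tuple dim, torch.linalg.norm computes a matrix norm (ord=2 is the largest singular value), whereas the deprecated torch.norm with p=2 took an elementwise vector norm over those dims. If exact numerical parity matters, torch.linalg.vector_norm is the drop-in for the old behaviour. A sketch with an arbitrary GRN-shaped input:

import torch

hidden_states = torch.randn(2, 7, 7, 64)  # (batch, height, width, channels), as in GRN

old = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True)           # elementwise L2 over H, W
vec = torch.linalg.vector_norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)
mat = torch.linalg.norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)  # spectral norm per (H, W) slice

print(torch.allclose(old, vec))  # True: vector_norm matches the deprecated call exactly
print(torch.allclose(old, mat))  # False in general: ord=2 over a dim pair is a matrix norm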
@@ -429,7 +429,7 @@ class JukeboxBottleneckBlock(nn.Module):
         entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8))  # entropy ie how diverse
         used_curr = (_codebook_elem >= self.threshold).sum()
         usage = torch.sum(usage)
-        dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
+        dk = torch.linalg.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
         return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}

     def preprocess(self, hidden_states):
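For the whole-tensor reductions in the Jukebox hunks the swap is exactly value-preserving: with no ord and no dim, torch.linalg.norm flattens its input and returns the 2-norm, which is the same Frobenius norm torch.norm returned by default. A self-contained check, using made-up codebook shapes:

import numpy as np
import torch

codebook = torch.randn(2048, 64)      # hypothetical codebook shape, for illustration only
old_codebook = torch.randn(2048, 64)

dk_old = torch.norm(codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
dk_new = torch.linalg.norm(codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
assert torch.allclose(dk_old, dk_new)  # identical: both are the Frobenius norm here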
@@ -437,11 +437,13 @@ class JukeboxBottleneckBlock(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

         if hidden_states.shape[-1] == self.codebook_width:
-            prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape))
+            prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(
+                np.prod(hidden_states.shape)
+            )
         elif hidden_states.shape[-1] == 2 * self.codebook_width:
             x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
-            prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
-                torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
+            prenorm = (torch.linalg.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
+                torch.linalg.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
             )

         # Normalise
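Aside from the rename and reflow, prenorm has a compact statistical reading: norm(x - mean(x)) / sqrt(numel) is the root-mean-square of the centered values, i.e. the biased standard deviation of the flattened tensor. A sketch with an arbitrary shape:

import numpy as np
import torch

hidden_states = torch.randn(1024, 64)  # arbitrary (N, codebook_width)-like shape

prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(
    np.prod(hidden_states.shape)
)
# Same quantity computed directly as a population standard deviation:
assert torch.allclose(prenorm, hidden_states.std(unbiased=False))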
@@ -517,7 +519,9 @@ class JukeboxBottleneckBlock(nn.Module):
         update_metrics = {}

         # Loss
-        commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape)
+        commit_loss = torch.linalg.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(
+            hidden_states.shape
+        )

         # Passthrough
         dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()
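The reformatted commit_loss line is a mean squared error in disguise (squared Frobenius norm divided by the element count), and the passthrough below it is the usual straight-through estimator. A minimal sketch with toy shapes:

import numpy as np
import torch

hidden_states = torch.randn(16, 64, requires_grad=True)
dequantised_states = torch.randn(16, 64)

commit_loss = torch.linalg.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(
    hidden_states.shape
)
assert torch.allclose(commit_loss, ((dequantised_states.detach() - hidden_states) ** 2).mean())

# Straight-through estimator: the forward pass uses the quantised values, while the
# backward pass routes gradients to hidden_states as if quantisation were the identity.
out = hidden_states + (dequantised_states - hidden_states).detach()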