Mirror of https://github.com/huggingface/transformers.git
Commit aa3778afc2 (parent c90e6e9625)
@@ -1346,8 +1346,8 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
         if hasattr(model_kwargs, "attention_mask"):
             position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1
         else:
-            position_ids = torch.range(
-                0, conditioning_embeds.shape[1] - 1, dtype=torch.long, device=conditioning_embeds.device
+            position_ids = torch.arange(
+                0, conditioning_embeds.shape[1], dtype=torch.long, device=conditioning_embeds.device
             )
         position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1)
 
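Note on this hunk: torch.range is deprecated and is end-inclusive, while torch.arange is end-exclusive, so dropping the "- 1" keeps the generated ids identical. A minimal sketch of the equivalence (the length 5 is illustrative, standing in for conditioning_embeds.shape[1]):

    import torch

    n = 5  # stands in for conditioning_embeds.shape[1]
    # torch.range(0, n - 1) was end-inclusive; torch.arange(0, n) is end-exclusive,
    # so both produce the integers 0 .. n - 1.
    position_ids = torch.arange(0, n, dtype=torch.long)
    print(position_ids)  # tensor([0, 1, 2, 3, 4])

Incidentally, the surrounding hasattr(model_kwargs, "attention_mask") context line always evaluates to False for a plain dict; "attention_mask" in model_kwargs would be the usual membership test, though that is outside the scope of this commit.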
@@ -100,7 +100,7 @@ class ConvNextV2GRN(nn.Module):
 
     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         # Compute and normalize global spatial feature maps
-        global_features = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True)
+        global_features = torch.linalg.norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)
         norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
         hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
 
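A reviewer-style caveat on this particular replacement (the sketch below is illustrative and not part of the commit): torch.norm with a numeric p and a 2-tuple dim computes the entrywise 2-norm over those dims, whereas torch.linalg.norm with ord=2 and a 2-tuple dim computes the matrix 2-norm (largest singular value). torch.linalg.vector_norm reproduces the old entrywise behaviour:

    import torch

    x = torch.randn(2, 4, 4, 8)  # (batch, height, width, channels), as in the GRN input

    a = torch.norm(x, p=2, dim=(1, 2), keepdim=True)                  # entrywise 2-norm (pre-change)
    b = torch.linalg.vector_norm(x, ord=2, dim=(1, 2), keepdim=True)  # entrywise 2-norm
    c = torch.linalg.norm(x, ord=2, dim=(1, 2), keepdim=True)         # spectral (matrix 2-) norm

    print(torch.allclose(a, b))  # True: vector_norm matches the old values
    print(torch.allclose(a, c))  # False in general: a matrix norm is computed instead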
@@ -429,7 +429,7 @@ class JukeboxBottleneckBlock(nn.Module):
         entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8))  # entropy ie how diverse
         used_curr = (_codebook_elem >= self.threshold).sum()
         usage = torch.sum(usage)
-        dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
+        dk = torch.linalg.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
         return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}
 
     def preprocess(self, hidden_states):
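In this hunk, by contrast, the call passes no ord or dim, and there the two APIs do agree: torch.norm defaults to the Frobenius norm and torch.linalg.norm defaults to the 2-norm of the flattened input, which is the same quantity. A quick check with an illustrative shape:

    import torch

    t = torch.randn(32, 64)  # stands in for self.codebook - old_codebook

    old = torch.norm(t)         # Frobenius norm (deprecated API)
    new = torch.linalg.norm(t)  # default: 2-norm of the flattened tensor

    print(torch.allclose(old, new))  # True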
@@ -437,11 +437,13 @@ class JukeboxBottleneckBlock(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
         if hidden_states.shape[-1] == self.codebook_width:
-            prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape))
+            prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(
+                np.prod(hidden_states.shape)
+            )
         elif hidden_states.shape[-1] == 2 * self.codebook_width:
             x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
-            prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
-                torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
+            prenorm = (torch.linalg.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
+                torch.linalg.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
             )
 
         # Normalise
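The prenorm quantity reformatted here is the root-mean-square deviation of the activations from their global mean, i.e. an uncorrected standard deviation. A sketch of that reading (the shape is illustrative):

    import numpy as np
    import torch

    hidden_states = torch.randn(128, 64)

    prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(
        np.prod(hidden_states.shape)
    )
    # Same value, up to float rounding, as the population standard deviation:
    print(prenorm.item(), hidden_states.std(unbiased=False).item())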
@@ -517,7 +519,9 @@ class JukeboxBottleneckBlock(nn.Module):
             update_metrics = {}
 
         # Loss
-        commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape)
+        commit_loss = torch.linalg.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(
+            hidden_states.shape
+        )
 
         # Passthrough
         dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()
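Beyond the norm swap, the surrounding lines are the standard VQ-VAE pieces: the commit loss shown equals torch.mean((dequantised_states.detach() - hidden_states) ** 2), and the passthrough line is the straight-through estimator. A minimal sketch of the passthrough trick with illustrative stand-in tensors:

    import torch

    hidden_states = torch.randn(4, 8, requires_grad=True)  # encoder output
    dequantised_states = torch.randn(4, 8)                 # stand-in for the codebook lookup

    # Straight-through estimator: the forward pass sees the quantised values,
    # but gradients flow to hidden_states as if quantisation were the identity.
    passthrough = hidden_states + (dequantised_states - hidden_states).detach()
    print(torch.allclose(passthrough, dequantised_states))  # True (up to rounding)

    passthrough.sum().backward()
    print(torch.equal(hidden_states.grad, torch.ones_like(hidden_states)))  # True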