Chore: Remove asserts and replace with ValueErrors.

Korbinian Poeppel 2025-07-01 15:56:32 +02:00
parent 7c239aaa1b
commit f8fcc6edcb

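The change applied throughout this diff is mechanical: every `assert cond, msg` guarding user-facing inputs becomes an explicit `if not cond: raise ValueError(msg)`. The motivation is that assert statements are stripped entirely when Python runs with -O (or PYTHONOPTIMIZE=1), and they raise AssertionError rather than an exception type callers can reasonably catch. A minimal sketch of the transformation, reusing the divisibility check from the first hunk below:

    # Before: silently skipped under `python -O`, raises AssertionError.
    assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."

    # After: always executed, raises ValueError, the conventional
    # exception type for arguments with an invalid value.
    if S % chunk_size != 0:
        raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")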

@@ -268,7 +268,8 @@ else:
     ], # all_states (matC_states (B, NH, (NC+1) * DHQK, DHHV), vecN_states (B, NH, (NC+1) * DHQK), scaMinter_states (B, NH, (NC+1)))
 ]:
     B, NH, S, DHQK = q.shape
-    assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+    if S % chunk_size != 0:
+        raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
     NC = S // chunk_size
     vecI = i.view(B, NH, NC, chunk_size)
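Callers that cannot guarantee a sequence length divisible by chunk_size now get a hard error instead of a check that vanishes under `python -O`. One way to satisfy it is to pad up front; a sketch under the (B, NH, S, DHQK) shape convention used above, with a helper name that is purely illustrative:

    import torch

    def pad_seq_to_chunk_multiple(t: torch.Tensor, chunk_size: int) -> torch.Tensor:
        # t: (B, NH, S, DHQK); zero-pad S up to the next multiple of chunk_size.
        pad = (-t.shape[2]) % chunk_size
        if pad == 0:
            return t
        # F.pad pads from the last dim backwards: (DHQK left/right, S left/right).
        return torch.nn.functional.pad(t, (0, 0, 0, pad))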
@@ -349,12 +350,12 @@ else:
     **kwargs,
 ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
     B, NH, S, DHQK = q.shape
-    assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+    if S % chunk_size != 0:
+        raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
     NC = S // chunk_size
     vecI = i.view(B, NH, NC, chunk_size)
     vecF = f.view(B, NH, NC, chunk_size)
-    assert 0

     # compute the gates, the g and the a and b vectors
     vecF_logsig = F.logsigmoid(vecF)
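Besides the assert-to-raise conversion, this hunk also deletes a stray `assert 0`, a leftover debug trap that would have made this code path raise unconditionally on every call (except under `python -O`, where it would have been silently skipped). The new ValueError is easy to pin down in a test; a hypothetical pytest sketch, with `mlstm_chunkwise` standing in for the real entry point, which this diff does not show:

    import pytest
    import torch

    def test_rejects_indivisible_sequence_length():
        B, NH, S, DHQK = 1, 2, 10, 16  # S=10 is not divisible by chunk_size=8
        q, k, v = (torch.randn(B, NH, S, DHQK) for _ in range(3))
        i, f = torch.randn(B, NH, S), torch.randn(B, NH, S)
        with pytest.raises(ValueError, match="not divisible by chunk size"):
            mlstm_chunkwise(q, k, v, i, f, chunk_size=8)  # placeholder name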
@@ -426,25 +427,31 @@ else:
     B, NH, DHQK = q.shape
     _, _, DHHV = v.shape
-    assert q.shape == k.shape, "q and k must have the same shape"
-    assert matC_old.shape == (
+    if q.shape != k.shape:
+        raise ValueError("q and k must have the same shape")
+    if matC_old.shape != (
         B,
         NH,
         DHQK,
         DHHV,
-    ), f"matC_old has wrong shape, got {matC_old.shape}"
-    assert vecN_old.shape == (
+    ):
+        raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
+    if vecN_old.shape != (
         B,
         NH,
         DHQK,
-    ), f"vecN_old has wrong shape, got {vecN_old.shape}"
-    assert scaM_old.shape == (
+    ):
+        raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
+    if scaM_old.shape != (
         B,
         NH,
         1,
-    ), f"scaM_old has wrong shape, got {scaM_old.shape}"
-    assert i.shape == (B, NH, 1), f"scaI has wrong shape, got {i.shape}"
-    assert f.shape == (B, NH, 1), f"scaF has wrong shape, got {f.shape}"
+    ):
+        raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
+    if i.shape != (B, NH, 1):
+        raise ValueError(f"scaI has wrong shape, got {i.shape}")
+    if f.shape != (B, NH, 1):
+        raise ValueError(f"scaF has wrong shape, got {f.shape}")

     # gates
     scaF_log = torch.nn.functional.logsigmoid(f)
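This hunk repeats the same if/raise pattern five times. If the repetition grows, one option is a small helper that centralizes the check while raising the identical ValueError; a suggestion only, not part of the commit:

    import torch

    def _check_shape(t: torch.Tensor, expected: tuple[int, ...], name: str) -> None:
        # torch.Size compares equal to plain tuples, so != works directly.
        if t.shape != expected:
            raise ValueError(f"{name} has wrong shape, got {tuple(t.shape)}")

    # e.g. inside the step kernel (names taken from the diff above):
    # _check_shape(matC_old, (B, NH, DHQK, DHHV), "matC_old")
    # _check_shape(vecN_old, (B, NH, DHQK), "vecN_old")
    # _check_shape(scaM_old, (B, NH, 1), "scaM_old")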
@@ -728,7 +735,8 @@ else:
         h_out = torch.concatenate(h_outs, dim=2)
     else:
-        assert S == 1, f"Received empty sequence (S={S}), require at least single element in the sequence."
+        if S != 1:
+            raise ValueError(f"Received empty sequence (S={S}), require at least single element in the sequence.")
         # process the sequence length in a single step
         # while this case is also captured by the regular mode above,
         # it avoids the overhead of the loop and calls the step kernel directly
@@ -1004,8 +1012,10 @@ else:
     x: torch.Tensor, # (B, S, NH, DH)
 ) -> torch.Tensor: # (B, S, NH * DH)
     B, S, NH, DH = x.shape
-    assert NH == self.num_heads, f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}"
-    assert DH == self.head_dim, f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}"
+    if NH != self.num_heads:
+        raise ValueError(f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}")
+    if DH != self.head_dim:
+        raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")

     x = self._layer_normalize(x)
     x = x.reshape(B, S, -1)
@@ -1131,7 +1141,8 @@ else:
     def forward(
         self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
     ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
-        assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
+        if x.ndim != 3:
+            raise ValueError(f"Input must have shape [B, S, D], got {x.shape}")
         B, S, _ = x.shape
         if self.config.weight_mode == "single":
             q = self.q(x)
@@ -1182,7 +1193,8 @@ else:
             S,
             self.v_dim // self.config.num_heads,
         )
-        assert h.shape == expected_h_shape, f"Got {h.shape}, expected {expected_h_shape}"
+        if h.shape != expected_h_shape:
+            raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")
         h = h.transpose(1, 2)
         h_norm = self.multihead_norm(h)
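The caller-visible effect across all hunks: code that previously relied on catching AssertionError from these checks must now catch ValueError. A minimal sketch, assuming an already-constructed layer object (construction details are outside this diff):

    import torch

    x = torch.randn(2, 512, 64, 8)  # 4-D input, but forward() requires [B, S, D]
    try:
        y, state = layer(x)  # `layer`: a hypothetical mLSTM layer instance
    except ValueError as err:  # raised AssertionError before this commit
        print(f"invalid input: {err}")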