Chore: Remove asserts and replace with ValueErrors.

Korbinian Poeppel 2025-07-01 15:56:32 +02:00
parent 7c239aaa1b
commit f8fcc6edcb

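The change is mechanical throughout: every "assert cond, msg" becomes an explicit "if not cond: raise ValueError(msg)". A minimal before/after sketch of the pattern, taken from the first hunk below:

    # Before: assert statements are stripped when Python runs with -O /
    # PYTHONOPTIMIZE, so the shape check silently disappears, and a
    # failure raises AssertionError rather than an exception callers
    # would conventionally catch for bad argument values.
    assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."

    # After: the check always runs and raises ValueError, the
    # conventional exception type for invalid argument values.
    if S % chunk_size != 0:
        raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")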

@@ -268,7 +268,8 @@ else:
         ], # all_states (matC_states (B, NH, (NC+1) * DHQK, DHHV), vecN_states (B, NH, (NC+1) * DHQK), scaMinter_states (B, NH, (NC+1)))
     ]:
         B, NH, S, DHQK = q.shape
-        assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+        if S % chunk_size != 0:
+            raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
         NC = S // chunk_size
         vecI = i.view(B, NH, NC, chunk_size)
@@ -349,12 +350,12 @@ else:
         **kwargs,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
         B, NH, S, DHQK = q.shape
-        assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+        if S % chunk_size != 0:
+            raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
         NC = S // chunk_size
         vecI = i.view(B, NH, NC, chunk_size)
         vecF = f.view(B, NH, NC, chunk_size)
-        assert 0
         # compute the gates, the g and the a and b vectors
         vecF_logsig = F.logsigmoid(vecF)
@@ -426,25 +427,31 @@ else:
         B, NH, DHQK = q.shape
         _, _, DHHV = v.shape
-        assert q.shape == k.shape, "q and k must have the same shape"
-        assert matC_old.shape == (
+        if q.shape != k.shape:
+            raise ValueError("q and k must have the same shape")
+        if matC_old.shape != (
             B,
             NH,
             DHQK,
             DHHV,
-        ), f"matC_old has wrong shape, got {matC_old.shape}"
-        assert vecN_old.shape == (
+        ):
+            raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
+        if vecN_old.shape != (
             B,
             NH,
             DHQK,
-        ), f"vecN_old has wrong shape, got {vecN_old.shape}"
-        assert scaM_old.shape == (
+        ):
+            raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
+        if scaM_old.shape != (
             B,
             NH,
             1,
-        ), f"scaM_old has wrong shape, got {scaM_old.shape}"
-        assert i.shape == (B, NH, 1), f"scaI has wrong shape, got {i.shape}"
-        assert f.shape == (B, NH, 1), f"scaF has wrong shape, got {f.shape}"
+        ):
+            raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
+        if i.shape != (B, NH, 1):
+            raise ValueError(f"scaI has wrong shape, got {i.shape}")
+        if f.shape != (B, NH, 1):
+            raise ValueError(f"scaF has wrong shape, got {f.shape}")
         # gates
         scaF_log = torch.nn.functional.logsigmoid(f)
@@ -728,7 +735,8 @@ else:
             h_out = torch.concatenate(h_outs, dim=2)
         else:
-            assert S == 1, f"Received empty sequence (S={S}), require at least single element in the sequence."
+            if S != 1:
+                raise ValueError(f"Received empty sequence (S={S}), require at least single element in the sequence.")
             # process the sequence length in a single step
             # while this case is also captured by the regular mode above,
             # it avoids the overhead of the loop and calls the step kernel directly
@@ -1004,8 +1012,10 @@ else:
         x: torch.Tensor, # (B, S, NH, DH)
     ) -> torch.Tensor: # (B, S, NH * DH)
         B, S, NH, DH = x.shape
-        assert NH == self.num_heads, f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}"
-        assert DH == self.head_dim, f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}"
+        if NH != self.num_heads:
+            raise ValueError(f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}")
+        if DH != self.head_dim:
+            raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
         x = self._layer_normalize(x)
         x = x.reshape(B, S, -1)
@@ -1131,7 +1141,8 @@ else:
     def forward(
         self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
     ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
-        assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
+        if x.ndim != 3:
+            raise ValueError(f"Input must have shape [B, S, D], got {x.shape}")
         B, S, _ = x.shape
         if self.config.weight_mode == "single":
             q = self.q(x)
@@ -1182,7 +1193,8 @@ else:
             S,
             self.v_dim // self.config.num_heads,
         )
-        assert h.shape == expected_h_shape, f"Got {h.shape}, expected {expected_h_shape}"
+        if h.shape != expected_h_shape:
+            raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")
         h = h.transpose(1, 2)
         h_norm = self.multihead_norm(h)
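The step-kernel hunk (@@ -426,25 +427,31 @@) repeats the same shape comparison five times. A hypothetical helper, not part of this commit, sketches how those checks could be consolidated in the same ValueError style (the name and signature of check_shape are assumptions for illustration):

    import torch

    def check_shape(name: str, tensor: torch.Tensor, expected: tuple) -> None:
        # Hypothetical helper (not from the commit): raises ValueError
        # in the same style as the checks introduced above.
        if tuple(tensor.shape) != expected:
            raise ValueError(f"{name} has wrong shape, got {tuple(tensor.shape)}")

    # The matC_old / vecN_old / scaM_old checks would then collapse to:
    # check_shape("matC_old", matC_old, (B, NH, DHQK, DHHV))
    # check_shape("vecN_old", vecN_old, (B, NH, DHQK))
    # check_shape("scaM_old", scaM_old, (B, NH, 1))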