From f8fcc6edcbdac588ee4fb3a58647463c6a1498f3 Mon Sep 17 00:00:00 2001
From: Korbinian Poeppel
Date: Tue, 1 Jul 2025 15:56:32 +0200
Subject: [PATCH] Chore: Remove asserts and replace with ValueErrors.

---
 .../models/xlstm/modeling_xlstm.py | 46 ++++++++++++-------
 1 file changed, 29 insertions(+), 17 deletions(-)

diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py
index 42a371e3681..d1e76ea3a01 100644
--- a/src/transformers/models/xlstm/modeling_xlstm.py
+++ b/src/transformers/models/xlstm/modeling_xlstm.py
@@ -268,7 +268,8 @@ else:
         ],  # all_states (matC_states (B, NH, (NC+1) * DHQK, DHHV), vecN_states (B, NH, (NC+1) * DHQK), scaMinter_states (B, NH, (NC+1)))
     ]:
         B, NH, S, DHQK = q.shape
-        assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+        if S % chunk_size != 0:
+            raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
         NC = S // chunk_size
 
         vecI = i.view(B, NH, NC, chunk_size)
@@ -349,12 +350,12 @@
         **kwargs,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
         B, NH, S, DHQK = q.shape
-        assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+        if S % chunk_size != 0:
+            raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
         NC = S // chunk_size
 
         vecI = i.view(B, NH, NC, chunk_size)
         vecF = f.view(B, NH, NC, chunk_size)
-        assert 0
 
         # compute the gates, the g and the a and b vectors
         vecF_logsig = F.logsigmoid(vecF)
@@ -426,25 +427,31 @@
         B, NH, DHQK = q.shape
         _, _, DHHV = v.shape
 
-        assert q.shape == k.shape, "q and k must have the same shape"
-        assert matC_old.shape == (
+        if q.shape != k.shape:
+            raise ValueError("q and k must have the same shape")
+        if matC_old.shape != (
             B,
             NH,
             DHQK,
             DHHV,
-        ), f"matC_old has wrong shape, got {matC_old.shape}"
-        assert vecN_old.shape == (
+        ):
+            raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
+        if vecN_old.shape != (
             B,
             NH,
             DHQK,
-        ), f"vecN_old has wrong shape, got {vecN_old.shape}"
-        assert scaM_old.shape == (
+        ):
+            raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
+        if scaM_old.shape != (
             B,
             NH,
             1,
-        ), f"scaM_old has wrong shape, got {scaM_old.shape}"
-        assert i.shape == (B, NH, 1), f"scaI has wrong shape, got {i.shape}"
-        assert f.shape == (B, NH, 1), f"scaF has wrong shape, got {f.shape}"
+        ):
+            raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
+        if i.shape != (B, NH, 1):
+            raise ValueError(f"scaI has wrong shape, got {i.shape}")
+        if f.shape != (B, NH, 1):
+            raise ValueError(f"scaF has wrong shape, got {f.shape}")
 
         # gates
         scaF_log = torch.nn.functional.logsigmoid(f)
@@ -728,7 +735,8 @@ else:
             h_out = torch.concatenate(h_outs, dim=2)
 
         else:
-            assert S == 1, f"Received empty sequence (S={S}), require at least single element in the sequence."
+            if S != 1:
+                raise ValueError(f"Received empty sequence (S={S}), require at least single element in the sequence.")
             # process the sequence length in a single step
             # while this case is also captured by the regular mode above,
             # it avoids the overhead of the loop and calls the step kernel directly
@@ -1004,8 +1012,10 @@ else:
             x: torch.Tensor,  # (B, S, NH, DH)
         ) -> torch.Tensor:  # (B, S, NH * DH)
             B, S, NH, DH = x.shape
-            assert NH == self.num_heads, f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}"
-            assert DH == self.head_dim, f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}"
+            if NH != self.num_heads:
+                raise ValueError(f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}")
+            if DH != self.head_dim:
+                raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
 
             x = self._layer_normalize(x)
             x = x.reshape(B, S, -1)
@@ -1131,7 +1141,8 @@ else:
         def forward(
             self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
         ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
-            assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
+            if x.ndim != 3:
+                raise ValueError(f"Input must have shape [B, S, D], got {x.shape}")
             B, S, _ = x.shape
             if self.config.weight_mode == "single":
                 q = self.q(x)
@@ -1182,7 +1193,8 @@ else:
                 S,
                 self.v_dim // self.config.num_heads,
             )
-            assert h.shape == expected_h_shape, f"Got {h.shape}, expected {expected_h_shape}"
+            if h.shape != expected_h_shape:
+                raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")
 
             h = h.transpose(1, 2)
             h_norm = self.multihead_norm(h)
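
The conversion applied throughout is the usual assert-to-exception pattern:
`assert` statements are stripped entirely when Python runs with `-O`, so they
cannot be relied on to validate user-facing inputs, whereas a raised
ValueError always fires and gives callers a concrete exception type to catch.
A minimal sketch of the before/after (the `check_heads` helper and its
arguments are illustrative, not taken from modeling_xlstm.py):

    import torch

    def check_heads(x: torch.Tensor, num_heads: int) -> None:
        # Before: silently skipped under `python -O`.
        #     assert x.shape[1] == num_heads, f"Expected {num_heads} heads"
        # After: always enforced, with a descriptive, catchable error.
        if x.shape[1] != num_heads:
            raise ValueError(f"Expected {num_heads} heads, got {x.shape[1]}, input shape: {x.shape}")

    check_heads(torch.zeros(2, 4, 16, 32), num_heads=4)  # passes the check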