mirror of https://github.com/huggingface/transformers.git
Chore: Remove asserts and replace with ValueErrors.
This commit is contained in:
parent
7c239aaa1b
commit
f8fcc6edcb
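
Background on the change: assert statements are compiled out entirely when Python runs with the -O flag, so assert-based shape checks can silently vanish in optimized deployments, and AssertionError is a poor exception type for callers to handle. A minimal before/after sketch of the pattern applied throughout this commit (standalone illustration, not code from the repo):

# before: compiled out under python -O; raises AssertionError otherwise
def check_divisible_assert(S: int, chunk_size: int) -> None:
    assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."

# after: always executes and raises a catchable ValueError
def check_divisible(S: int, chunk_size: int) -> None:
    if S % chunk_size != 0:
        raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")

try:
    check_divisible(100, 64)
except ValueError as err:
    print(err)  # Sequence length 100 is not divisible by chunk size 64.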
@@ -268,7 +268,8 @@ else:
         ],  # all_states (matC_states (B, NH, (NC+1) * DHQK, DHHV), vecN_states (B, NH, (NC+1) * DHQK), scaMinter_states (B, NH, (NC+1)))
     ]:
         B, NH, S, DHQK = q.shape
-        assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+        if S % chunk_size != 0:
+            raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
         NC = S // chunk_size
 
         vecI = i.view(B, NH, NC, chunk_size)
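
With the assert replaced, callers still need S to be a multiple of chunk_size before NC = S // chunk_size is computed. One way to guarantee that is to right-pad the sequence dimension; the helper below is a hypothetical sketch, not part of this commit:

import torch
import torch.nn.functional as F

def pad_to_chunk_multiple(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Right-pad dim 2 (S) of a (B, NH, S, D) tensor with zeros so that
    # S becomes a multiple of chunk_size and the ValueError above is avoided.
    remainder = x.shape[2] % chunk_size
    if remainder == 0:
        return x
    return F.pad(x, (0, 0, 0, chunk_size - remainder))

q = torch.randn(1, 4, 100, 64)    # S=100, not divisible by chunk_size=64
q = pad_to_chunk_multiple(q, 64)  # S -> 128, so NC = 128 // 64 = 2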
@@ -349,12 +350,12 @@ else:
         **kwargs,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
         B, NH, S, DHQK = q.shape
-        assert S % chunk_size == 0, f"Sequence length {S} is not divisible by chunk size {chunk_size}."
+        if S % chunk_size != 0:
+            raise ValueError(f"Sequence length {S} is not divisible by chunk size {chunk_size}.")
         NC = S // chunk_size
 
         vecI = i.view(B, NH, NC, chunk_size)
         vecF = f.view(B, NH, NC, chunk_size)
-        assert 0
 
         # compute the gates, the g and the a and b vectors
         vecF_logsig = F.logsigmoid(vecF)
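
Note that this hunk also drops a stray assert 0 (a debug leftover that made the path unreachable). The gate computation that follows builds per-chunk cumulative quantities from the log-sigmoid forget gates. A rough sketch of the usual chunkwise-mLSTM gate preprocessing, following the vecI/vecF naming above (an approximation, not the exact repo code):

import torch
import torch.nn.functional as F

B, NH, NC, chunk_size = 1, 2, 4, 8
vecI = torch.randn(B, NH, NC, chunk_size)  # input gate preactivations, chunked
vecF = torch.randn(B, NH, NC, chunk_size)  # forget gate preactivations, chunked

vecF_logsig = F.logsigmoid(vecF)   # log forget gates, numerically stable
vecB = vecF_logsig.cumsum(dim=-1)  # cumulative decay within each chunk (the "b" vector)
scaG = vecB[..., -1:]              # total decay over the whole chunk (the "g" scalar)
vecA = scaG - vecB + vecI          # decay from position to chunk end, plus input gate (the "a" vector)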
@@ -426,25 +427,31 @@ else:
 
         B, NH, DHQK = q.shape
         _, _, DHHV = v.shape
-        assert q.shape == k.shape, "q and k must have the same shape"
-        assert matC_old.shape == (
+        if q.shape != k.shape:
+            raise ValueError("q and k must have the same shape")
+        if matC_old.shape != (
             B,
             NH,
             DHQK,
             DHHV,
-        ), f"matC_old has wrong shape, got {matC_old.shape}"
-        assert vecN_old.shape == (
+        ):
+            raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
+        if vecN_old.shape != (
             B,
             NH,
             DHQK,
-        ), f"vecN_old has wrong shape, got {vecN_old.shape}"
-        assert scaM_old.shape == (
+        ):
+            raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
+        if scaM_old.shape != (
             B,
             NH,
             1,
-        ), f"scaM_old has wrong shape, got {scaM_old.shape}"
-        assert i.shape == (B, NH, 1), f"scaI has wrong shape, got {i.shape}"
-        assert f.shape == (B, NH, 1), f"scaF has wrong shape, got {f.shape}"
+        ):
+            raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
+        if i.shape != (B, NH, 1):
+            raise ValueError(f"scaI has wrong shape, got {i.shape}")
+        if f.shape != (B, NH, 1):
+            raise ValueError(f"scaF has wrong shape, got {f.shape}")
 
         # gates
         scaF_log = torch.nn.functional.logsigmoid(f)
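
The step function above now validates all five recurrent inputs explicitly. For reference, a zero-initialized state that satisfies every check, with the shapes taken directly from the conditions in this hunk (dimension values arbitrary; a usage sketch, not repo code):

import torch

B, NH, DHQK, DHHV = 2, 4, 64, 64

matC_old = torch.zeros(B, NH, DHQK, DHHV)  # memory cell matrix C
vecN_old = torch.zeros(B, NH, DHQK)        # normalizer state n
scaM_old = torch.zeros(B, NH, 1)           # running max state m
i = torch.zeros(B, NH, 1)                  # scalar input gate preactivation
f = torch.zeros(B, NH, 1)                  # scalar forget gate preactivation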
@@ -728,7 +735,8 @@ else:
             h_out = torch.concatenate(h_outs, dim=2)
 
         else:
-            assert S == 1, f"Received empty sequence (S={S}), require at least single element in the sequence."
+            if S != 1:
+                raise ValueError(f"Received empty sequence (S={S}), require at least single element in the sequence.")
             # process the sequence length in a single step
             # while this case is also captured by the regular mode above,
             # it avoids the overhead of the loop and calls the step kernel directly
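
This else branch is the single-token decode path: for S == 1 the code skips the chunked loop and calls the step kernel once. A self-contained sketch of that dispatch, with placeholder functions standing in for the two kernels (all names here are hypothetical):

import torch

def mlstm_chunkwise(q, k, v, chunk_size):  # placeholder for the looped chunk path
    return q

def mlstm_step(q, k, v):                   # placeholder for the single-step kernel
    return q

def mlstm_forward(q, k, v, chunk_size=64):
    S = q.shape[2]
    if S > 1:
        return mlstm_chunkwise(q, k, v, chunk_size)  # regular chunked mode
    if S != 1:
        raise ValueError(f"Received empty sequence (S={S}), require at least single element in the sequence.")
    return mlstm_step(q, k, v)  # S == 1: call the step kernel directly, no loop overhead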
@@ -1004,8 +1012,10 @@ else:
         x: torch.Tensor,  # (B, S, NH, DH)
     ) -> torch.Tensor:  # (B, S, NH * DH)
         B, S, NH, DH = x.shape
-        assert NH == self.num_heads, f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}"
-        assert DH == self.head_dim, f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}"
+        if NH != self.num_heads:
+            raise ValueError(f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}")
+        if DH != self.head_dim:
+            raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
 
         x = self._layer_normalize(x)
         x = x.reshape(B, S, -1)
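
The method above normalizes per head and then flattens the head dimension, turning (B, S, NH, DH) into (B, S, NH * DH). A minimal stand-in for that pattern (an assumed implementation, not the repo's class):

import torch
import torch.nn.functional as F

class MultiHeadLayerNorm(torch.nn.Module):
    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, head_dim
        self.weight = torch.nn.Parameter(torch.ones(num_heads * head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # (B, S, NH, DH)
        B, S, NH, DH = x.shape
        if NH != self.num_heads:
            raise ValueError(f"Expected {self.num_heads} heads, got {NH}, input shape: {x.shape}")
        if DH != self.head_dim:
            raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
        x = F.layer_norm(x, (DH,))                 # normalize each head independently
        return x.reshape(B, S, -1) * self.weight   # flatten heads: (B, S, NH * DH)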
@@ -1131,7 +1141,8 @@ else:
     def forward(
         self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
     ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
-        assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
+        if x.ndim != 3:
+            raise ValueError(f"Input must have shape [B, S, D], got {x.shape}")
         B, S, _ = x.shape
         if self.config.weight_mode == "single":
             q = self.q(x)
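
The weight_mode == "single" branch computes q, k, and v with separate projections; a fused mode would compute them in one matmul and split the result. A simplified sketch of the two layouts (the fused naming and equal q/k/v widths are assumptions, not taken from the repo):

import torch

class Projections(torch.nn.Module):
    def __init__(self, dim: int, weight_mode: str = "single"):
        super().__init__()
        self.weight_mode = weight_mode
        if weight_mode == "single":
            self.q = torch.nn.Linear(dim, dim, bias=False)
            self.k = torch.nn.Linear(dim, dim, bias=False)
            self.v = torch.nn.Linear(dim, dim, bias=False)
        else:  # fused: one matmul, then split
            self.qkv = torch.nn.Linear(dim, 3 * dim, bias=False)

    def forward(self, x: torch.Tensor):
        if x.ndim != 3:
            raise ValueError(f"Input must have shape [B, S, D], got {x.shape}")
        if self.weight_mode == "single":
            return self.q(x), self.k(x), self.v(x)
        return self.qkv(x).chunk(3, dim=-1)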
@@ -1182,7 +1193,8 @@ else:
             S,
             self.v_dim // self.config.num_heads,
         )
-        assert h.shape == expected_h_shape, f"Got {h.shape}, expected {expected_h_shape}"
+        if h.shape != expected_h_shape:
+            raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")
 
         h = h.transpose(1, 2)
         h_norm = self.multihead_norm(h)
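
This final check guards a layout handoff: the kernel output h arrives heads-first as (B, NH, S, v_dim // num_heads) and must be transposed to heads-last before the multi-head norm, matching the (B, S, NH, DH) check earlier in the diff. A short shape-flow sketch with arbitrary sizes:

import torch

B, NH, S, DHV = 2, 4, 16, 32
h = torch.randn(B, NH, S, DHV)  # kernel output, heads-first layout

expected_h_shape = (B, NH, S, DHV)
if h.shape != expected_h_shape:
    raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")

h = h.transpose(1, 2)  # -> (B, S, NH, DHV), heads-last for the multi-head norm
print(h.shape)         # torch.Size([2, 16, 4, 32])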