fix: replace asserts by error (#13894)

This commit is contained in:
Siarhei Melnik 2021-10-06 01:08:48 +03:00 committed by GitHub
parent f099249cf1
commit 7af7d7ce05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -172,7 +172,8 @@ class FlaxMultiHeadSelfAttention(nn.Module):
self.dim = self.config.dim
self.dropout = nn.Dropout(rate=self.config.attention_dropout)
assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
if not (self.dim % self.n_heads == 0):
raise ValueError(f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}")
self.q_lin = nn.Dense(
self.dim,