mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Raise exceptions instead of asserts in src/transformers/models/bart/modeling_flax_[bart, marian, mbart, pegasus].py (#13939)
* Raise exceptions instead of asserts * fix: fixed failing quality check with copies * fix: fixed max line length * rerun github ci, failed to install dependencies
This commit is contained in:
parent
7fb2a8b3d9
commit
b65c389769
@ -237,9 +237,11 @@ class FlaxBartAttention(nn.Module):
|
||||
|
||||
def setup(self) -> None:
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert (
|
||||
self.head_dim * self.num_heads == self.embed_dim
|
||||
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
dense = partial(
|
||||
nn.Dense,
|
||||
|
@ -241,9 +241,11 @@ class FlaxMarianAttention(nn.Module):
|
||||
|
||||
def setup(self) -> None:
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert (
|
||||
self.head_dim * self.num_heads == self.embed_dim
|
||||
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
dense = partial(
|
||||
nn.Dense,
|
||||
|
@ -248,9 +248,11 @@ class FlaxMBartAttention(nn.Module):
|
||||
|
||||
def setup(self) -> None:
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert (
|
||||
self.head_dim * self.num_heads == self.embed_dim
|
||||
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
dense = partial(
|
||||
nn.Dense,
|
||||
|
@ -241,9 +241,11 @@ class FlaxPegasusAttention(nn.Module):
|
||||
|
||||
def setup(self) -> None:
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert (
|
||||
self.head_dim * self.num_heads == self.embed_dim
|
||||
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
dense = partial(
|
||||
nn.Dense,
|
||||
|
Loading…
Reference in New Issue
Block a user