mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Change asserts in src/transformers/models/xlnet/ to raise ValueError (#14088)
* Change asserts in src/transformers/models/xlnet/ to raise ValueError
* Update src/transformers/models/xlnet/modeling_tf_xlnet.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
e9d2a639f4
commit
9e4ea25175
@ -180,11 +180,13 @@ class XLNetConfig(PretrainedConfig):
|
||||
self.d_model = d_model
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
assert d_model % n_head == 0
|
||||
if d_model % n_head != 0:
|
||||
raise ValueError(f"'d_model % n_head' ({d_model % n_head}) should be equal to 0")
|
||||
if "d_head" in kwargs:
|
||||
assert (
|
||||
kwargs["d_head"] == d_model // n_head
|
||||
), f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})"
|
||||
if kwargs["d_head"] != d_model // n_head:
|
||||
raise ValueError(
|
||||
f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})"
|
||||
)
|
||||
self.d_head = d_model // n_head
|
||||
self.ff_activation = ff_activation
|
||||
self.d_inner = d_inner
|
||||
|
@ -561,7 +561,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
|
||||
|
||||
if bsz is not None:
|
||||
assert bsz % 2 == 0, f"With bi_data, the batch size {bsz} should be divisible by 2"
|
||||
if bsz % 2 != 0:
|
||||
raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2")
|
||||
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
|
||||
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user