Change asserts in src/transformers/models/xlnet/ to raise ValueError (#14088)

* Change asserts in src/transformers/models/xlnet/ to raise ValueError

* Update src/transformers/models/xlnet/modeling_tf_xlnet.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Weston King-Leatham 2021-10-21 07:27:32 -04:00 committed by GitHub
parent e9d2a639f4
commit 9e4ea25175
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 5 deletions

View File

@ -180,11 +180,13 @@ class XLNetConfig(PretrainedConfig):
self.d_model = d_model
self.n_layer = n_layer
self.n_head = n_head
assert d_model % n_head == 0
if d_model % n_head != 0:
raise ValueError(f"'d_model % n_head' ({d_model % n_head}) should be equal to 0")
if "d_head" in kwargs:
assert (
kwargs["d_head"] == d_model // n_head
), f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})"
if kwargs["d_head"] != d_model // n_head:
raise ValueError(
f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})"
)
self.d_head = d_model // n_head
self.ff_activation = ff_activation
self.d_inner = d_inner

View File

@ -561,7 +561,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
if bsz is not None:
assert bsz % 2 == 0, f"With bi_data, the batch size {bsz} should be divisible by 2"
if bsz % 2 != 0:
raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2")
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
else: