From 9e4ea2517504f144e74d9a356d58e9f2be32b3fa Mon Sep 17 00:00:00 2001 From: Weston King-Leatham <71475274+WestonKing-Leatham@users.noreply.github.com> Date: Thu, 21 Oct 2021 07:27:32 -0400 Subject: [PATCH] Change asserts in src/transformers/models/xlnet/ to raise ValueError (#14088) * Change asserts in src/transformers/models/xlnet/ to raise ValueError * Update src/transformers/models/xlnet/modeling_tf_xlnet.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/xlnet/configuration_xlnet.py | 10 ++++++---- src/transformers/models/xlnet/modeling_tf_xlnet.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/xlnet/configuration_xlnet.py b/src/transformers/models/xlnet/configuration_xlnet.py index 1a87bcd9f44..131e867ff78 100644 --- a/src/transformers/models/xlnet/configuration_xlnet.py +++ b/src/transformers/models/xlnet/configuration_xlnet.py @@ -180,11 +180,13 @@ class XLNetConfig(PretrainedConfig): self.d_model = d_model self.n_layer = n_layer self.n_head = n_head - assert d_model % n_head == 0 + if d_model % n_head != 0: + raise ValueError(f"'d_model % n_head' ({d_model % n_head}) should be equal to 0") if "d_head" in kwargs: - assert ( - kwargs["d_head"] == d_model // n_head - ), f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})" + if kwargs["d_head"] != d_model // n_head: + raise ValueError( + f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})" + ) self.d_head = d_model // n_head self.ff_activation = ff_activation self.d_inner = d_inner diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index 859f7f8dce5..71a7acd5dde 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -561,7 +561,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len) if bsz is not None: - assert bsz % 2 == 0, f"With bi_data, the batch size {bsz} should be divisible by 2" + if bsz % 2 != 0: + raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2") fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2) bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2) else: