mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
e5dcceb82c
commit
7f3dab39b5
@ -221,13 +221,17 @@ def _compute_mask_indices(
|
||||
if mask_length < 1:
|
||||
raise ValueError("`mask_length` has to be bigger than 0.")
|
||||
|
||||
if mask_length > sequence_length:
|
||||
raise ValueError(
|
||||
tf.debugging.assert_less(
|
||||
mask_length,
|
||||
sequence_length,
|
||||
message=(
|
||||
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
|
||||
f" `sequence_length`: {sequence_length}`"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# compute number of masked spans in batch
|
||||
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
|
||||
num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
|
||||
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
|
||||
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
|
||||
|
||||
|
@ -262,13 +262,17 @@ def _compute_mask_indices(
|
||||
if mask_length < 1:
|
||||
raise ValueError("`mask_length` has to be bigger than 0.")
|
||||
|
||||
if mask_length > sequence_length:
|
||||
raise ValueError(
|
||||
tf.debugging.assert_less(
|
||||
mask_length,
|
||||
sequence_length,
|
||||
message=(
|
||||
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
|
||||
f" `sequence_length`: {sequence_length}`"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# compute number of masked spans in batch
|
||||
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
|
||||
num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
|
||||
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
|
||||
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user