TF: serializable hubert (#20966)

* serializable hubert
This commit is contained in:
Joao Gante 2023-01-17 13:07:37 +00:00 committed by GitHub
parent e5dcceb82c
commit 7f3dab39b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 8 deletions

View File

@ -221,13 +221,17 @@ def _compute_mask_indices(
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
tf.debugging.assert_less(
mask_length,
sequence_length,
message=(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`"
)
),
)
# compute number of masked spans in batch
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)

View File

@ -262,13 +262,17 @@ def _compute_mask_indices(
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
tf.debugging.assert_less(
mask_length,
sequence_length,
message=(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`"
)
),
)
# compute number of masked spans in batch
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)