Refactor TFP call to just sigmoid() (#29641)

* Refactor TFP call to just sigmoid()

* Make sure we cast to the right dtype
This commit is contained in:
Matt 2024-03-13 17:51:13 +00:00 committed by GitHub
parent a7e5e15472
commit 31d01150ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,6 @@ import numpy as np
from ..utils import (
add_end_docstrings,
is_tensorflow_probability_available,
is_tf_available,
is_torch_available,
requires_backends,
@ -21,9 +20,8 @@ if is_torch_available():
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
if is_tf_available() and is_tensorflow_probability_available():
if is_tf_available():
import tensorflow as tf
import tensorflow_probability as tfp
from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
@ -249,8 +247,9 @@ class TableQuestionAnsweringPipeline(Pipeline):
all_logits.append(logits)
dist_per_token = tfp.distributions.Bernoulli(logits=logits)
probabilities = dist_per_token.probs_parameter() * tf.cast(attention_mask_example, tf.float32)
probabilities = tf.math.sigmoid(tf.cast(logits, tf.float32)) * tf.cast(
attention_mask_example, tf.float32
)
coords_to_probs = collections.defaultdict(list)
token_type_ids_example = token_type_ids_example