mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Refactor TFP call to just sigmoid() (#29641)
* Refactor TFP call to just sigmoid() * Make sure we cast to the right dtype
This commit is contained in:
parent
a7e5e15472
commit
31d01150ad
@ -5,7 +5,6 @@ import numpy as np
|
||||
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_tensorflow_probability_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
requires_backends,
|
||||
@ -21,9 +20,8 @@ if is_torch_available():
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
)
|
||||
|
||||
if is_tf_available() and is_tensorflow_probability_available():
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
from ..models.auto.modeling_tf_auto import (
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
||||
@ -249,8 +247,9 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
|
||||
all_logits.append(logits)
|
||||
|
||||
dist_per_token = tfp.distributions.Bernoulli(logits=logits)
|
||||
probabilities = dist_per_token.probs_parameter() * tf.cast(attention_mask_example, tf.float32)
|
||||
probabilities = tf.math.sigmoid(tf.cast(logits, tf.float32)) * tf.cast(
|
||||
attention_mask_example, tf.float32
|
||||
)
|
||||
|
||||
coords_to_probs = collections.defaultdict(list)
|
||||
token_type_ids_example = token_type_ids_example
|
||||
|
Loading…
Reference in New Issue
Block a user