mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Refactor TFP call to just sigmoid() (#29641)
* Refactor TFP call to just sigmoid() * Make sure we cast to the right dtype
This commit is contained in:
parent
a7e5e15472
commit
31d01150ad
@ -5,7 +5,6 @@ import numpy as np
|
|||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
is_tensorflow_probability_available,
|
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
@ -21,9 +20,8 @@ if is_torch_available():
|
|||||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available() and is_tensorflow_probability_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow_probability as tfp
|
|
||||||
|
|
||||||
from ..models.auto.modeling_tf_auto import (
|
from ..models.auto.modeling_tf_auto import (
|
||||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
||||||
@ -249,8 +247,9 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
|
|
||||||
all_logits.append(logits)
|
all_logits.append(logits)
|
||||||
|
|
||||||
dist_per_token = tfp.distributions.Bernoulli(logits=logits)
|
probabilities = tf.math.sigmoid(tf.cast(logits, tf.float32)) * tf.cast(
|
||||||
probabilities = dist_per_token.probs_parameter() * tf.cast(attention_mask_example, tf.float32)
|
attention_mask_example, tf.float32
|
||||||
|
)
|
||||||
|
|
||||||
coords_to_probs = collections.defaultdict(list)
|
coords_to_probs = collections.defaultdict(list)
|
||||||
token_type_ids_example = token_type_ids_example
|
token_type_ids_example = token_type_ids_example
|
||||||
|
Loading…
Reference in New Issue
Block a user