Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)

v4.39 deprecations 🧼 (#29492)

This commit is contained in:
parent 979fccc90f
commit ffe60fdcd6
@@ -336,12 +336,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 - process
 - finalize
 
-## Utilities
-
-[[autodoc]] top_k_top_p_filtering
-
-[[autodoc]] tf_top_k_top_p_filtering
-
 ## Streamers
 
 [[autodoc]] TextStreamer
@@ -335,12 +335,6 @@ generation_output[:2]
 - process
 - finalize
 
-## Utilities
-
-[[autodoc]] top_k_top_p_filtering
-
-[[autodoc]] tf_top_k_top_p_filtering
-
 ## Streamers
 
 [[autodoc]] TextStreamer
@@ -330,12 +330,6 @@ generation_output[:2]
 - process
 - finalize
 
-## Utilities
-
-[[autodoc]] top_k_top_p_filtering
-
-[[autodoc]] tf_top_k_top_p_filtering
-
 ## Streamers
 
 [[autodoc]] TextStreamer
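Both autodoc'd helpers above are deleted in v4.39 rather than merely deprecated, which is why their Utilities entries disappear from each of the three documentation copies touched here. For migration, a minimal sketch using the warper classes the deprecation warnings point to (not part of this commit; shapes and values are made up):

import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

logits = torch.randn(2, 50257)  # (batch_size, vocab_size), hypothetical sizes

# Equivalent of the removed top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4);
# the warpers ignore their input_ids argument here, just as the removed helper did internally.
logits = TopKLogitsWarper(top_k=10, min_tokens_to_keep=4)(None, logits)
logits = TopPLogitsWarper(top_p=0.6, min_tokens_to_keep=4)(None, logits)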
@@ -1409,7 +1409,6 @@ else:
             "TypicalLogitsWarper",
             "UnbatchedClassifierFreeGuidanceLogitsProcessor",
             "WhisperTimeStampLogitsProcessor",
-            "top_k_top_p_filtering",
         ]
     )
     _import_structure["generation_utils"] = []
@@ -3814,7 +3813,6 @@ else:
             "TFTemperatureLogitsWarper",
             "TFTopKLogitsWarper",
             "TFTopPLogitsWarper",
-            "tf_top_k_top_p_filtering",
         ]
     )
     _import_structure["generation_tf_utils"] = []
@@ -6206,7 +6204,6 @@ if TYPE_CHECKING:
         TypicalLogitsWarper,
         UnbatchedClassifierFreeGuidanceLogitsProcessor,
         WhisperTimeStampLogitsProcessor,
-        top_k_top_p_filtering,
     )
     from .modeling_utils import PreTrainedModel
     from .models.albert import (
@@ -8178,7 +8175,6 @@ if TYPE_CHECKING:
         TFTemperatureLogitsWarper,
         TFTopKLogitsWarper,
         TFTopPLogitsWarper,
-        tf_top_k_top_p_filtering,
     )
     from .keras_callbacks import KerasMetricCallback, PushToHubCallback
     from .modeling_tf_utils import (
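With these export lists trimmed, `from transformers import top_k_top_p_filtering` (or the `tf_` variant) stops working once the symbols are dropped, while the warper classes remain exported. A sketch of the import change downstream code needs:

# Before (<= v4.38):
#   from transformers import top_k_top_p_filtering, tf_top_k_top_p_filtering
# After this commit, import the warpers instead:
from transformers import TopKLogitsWarper, TopPLogitsWarper  # PyTorch
from transformers import TFTopKLogitsWarper, TFTopPLogitsWarper  # TensorFlow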
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import math
-import warnings
 from collections import OrderedDict
 
 import torch
@@ -138,14 +137,6 @@ class AccurateGELUActivation(nn.Module):
         return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
 
 
-class SiLUActivation(nn.SiLU):
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "The SiLUActivation class has been deprecated and will be removed in v4.39. Please use nn.SiLU instead.",
-        )
-        super().__init__(*args, **kwargs)
-
-
 class MishActivation(nn.Module):
     """
     See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
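The removed `SiLUActivation` subclass added nothing beyond a deprecation warning on construction, so call sites can instantiate `torch.nn.SiLU` directly. A sketch of the one-line migration (not part of this commit):

import torch.nn as nn

# Before (<= v4.38): act = transformers.activations.SiLUActivation()
act = nn.SiLU()  # drop-in replacement; SiLUActivation subclassed nn.SiLU with no behavioral changes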
@@ -88,7 +88,6 @@ else:
     ]
     _import_structure["utils"] = [
         "GenerationMixin",
-        "top_k_top_p_filtering",
         "GreedySearchEncoderDecoderOutput",
         "GreedySearchDecoderOnlyOutput",
         "SampleEncoderDecoderOutput",
@@ -130,7 +129,6 @@ else:
     ]
     _import_structure["tf_utils"] = [
         "TFGenerationMixin",
-        "tf_top_k_top_p_filtering",
         "TFGreedySearchDecoderOnlyOutput",
         "TFGreedySearchEncoderDecoderOutput",
         "TFSampleEncoderDecoderOutput",
@@ -241,7 +239,6 @@ if TYPE_CHECKING:
         GreedySearchEncoderDecoderOutput,
         SampleDecoderOnlyOutput,
         SampleEncoderDecoderOutput,
-        top_k_top_p_filtering,
     )
 
     try:
@@ -279,7 +276,6 @@ if TYPE_CHECKING:
         TFGreedySearchEncoderDecoderOutput,
         TFSampleDecoderOnlyOutput,
         TFSampleEncoderDecoderOutput,
-        tf_top_k_top_p_filtering,
     )
 
     try:
@@ -3088,68 +3088,6 @@ class TFGenerationMixin:
         return generated
 
 
-def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
-    """
-    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        top_k (`int`, *optional*, defaults to 0):
-            If > 0, only keep the top k tokens with highest probability (top-k filtering)
-        top_p (`float`, *optional*, defaults to 1.0):
-            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
-            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimumber of tokens we keep per batch example in the output.
-
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-
-    warnings.warn(
-        "`tf_top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TFTopKLogitsWarper` and "
-        "`TFTopPLogitsWarper` instead.",
-        DeprecationWarning,
-    )
-
-    logits_shape = shape_list(logits)
-
-    if top_k > 0:
-        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
-        logits = tf.where(indices_to_remove, filter_value, logits)
-    if top_p < 1.0:
-        sorted_indices = tf.argsort(logits, direction="DESCENDING")
-        sorted_logits = tf.gather(
-            logits, sorted_indices, axis=-1, batch_dims=1
-        )  # expects logits to be of dim (batch_size, vocab_size)
-
-        cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)
-
-        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs > top_p
-
-        if min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-            sorted_indices_to_remove = tf.concat(
-                [
-                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
-                    sorted_indices_to_remove[:, min_tokens_to_keep:],
-                ],
-                -1,
-            )
-
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove = tf.concat(
-            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, :-1]],
-            -1,
-        )
-        # scatter sorted tensors to original indexing
-        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
-        logits = tf.where(indices_to_remove, filter_value, logits)
-    return logits
-
-
 def scatter_values_on_batch_indices(values, batch_indices):
     shape = shape_list(batch_indices)
     # broadcast batch dim to shape
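The deleted TF helper's top-k and nucleus logic lives on in `TFTopKLogitsWarper` and `TFTopPLogitsWarper`, as its deprecation warning states. A rough equivalent of `tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6)` with the warpers (a sketch; the TF warpers take extra `input_ids` and `cur_len` arguments, which these two ignore, mirroring the deleted helper):

import tensorflow as tf
from transformers import TFTopKLogitsWarper, TFTopPLogitsWarper

logits = tf.random.normal((2, 50257))  # (batch_size, vocab_size), hypothetical sizes

logits = TFTopKLogitsWarper(top_k=10, filter_value=-float("inf"), min_tokens_to_keep=1)(None, logits, cur_len=None)
logits = TFTopPLogitsWarper(top_p=0.6, filter_value=-float("inf"), min_tokens_to_keep=1)(None, logits, cur_len=None)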
@@ -4810,47 +4810,6 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at
         return outputs
 
 
-def top_k_top_p_filtering(
-    logits: torch.FloatTensor,
-    top_k: int = 0,
-    top_p: float = 1.0,
-    filter_value: float = -float("Inf"),
-    min_tokens_to_keep: int = 1,
-) -> torch.FloatTensor:
-    """
-    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        top_k (`int`, *optional*, defaults to 0):
-            If > 0, only keep the top k tokens with highest probability (top-k filtering)
-        top_p (`float`, *optional*, defaults to 1.0):
-            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
-            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimumber of tokens we keep per batch example in the output.
-
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-    warnings.warn(
-        "`top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TopKLogitsWarper` and `TopPLogitsWarper` "
-        "instead.",
-        DeprecationWarning,
-    )
-
-    if top_k > 0:
-        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    if 0 <= top_p <= 1.0:
-        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    return logits
-
-
 def _ranking_fast(
     context_hidden: torch.FloatTensor,
     next_hidden: torch.FloatTensor,
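As the deleted body makes explicit, the PyTorch helper had already been reduced to a thin wrapper over `TopKLogitsWarper` and `TopPLogitsWarper`, so inlining the warpers is behavior-preserving. A sampling-step sketch with hypothetical names:

import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

def sample_next_token(logits: torch.FloatTensor, top_k: int = 50, top_p: float = 0.9) -> torch.LongTensor:
    # logits: (batch_size, vocab_size); apply the same filtering the removed helper performed
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k)(None, logits)
    if top_p < 1.0:
        logits = TopPLogitsWarper(top_p=top_p)(None, logits)
    probs = torch.softmax(logits, dim=-1)  # filtered positions are -inf, so they get zero probability
    return torch.multinomial(probs, num_samples=1)  # (batch_size, 1)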
@@ -129,10 +129,7 @@ class LlamaRotaryEmbedding(nn.Module):
         return self._cos_cached
 
     @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        if seq_len is not None:
-            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
-
+    def forward(self, x, position_ids):
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
@@ -151,17 +148,17 @@ class LlamaRotaryEmbedding(nn.Module):
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
     """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
-    def forward(self, x, position_ids, seq_len=None):
+    def forward(self, x, position_ids):
         # difference to the original RoPE: a scaling factor is aplied to the position ids
         position_ids = position_ids.float() / self.scaling_factor
-        cos, sin = super().forward(x, position_ids, seq_len)
+        cos, sin = super().forward(x, position_ids)
         return cos, sin
 
 
 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
     """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
-    def forward(self, x, position_ids, seq_len=None):
+    def forward(self, x, position_ids):
         # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_position_embeddings:
@@ -173,7 +170,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
             )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation
 
-        cos, sin = super().forward(x, position_ids, seq_len)
+        cos, sin = super().forward(x, position_ids)
         return cos, sin
 
 
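All three rotary-embedding `forward` methods lose the unused `seq_len` parameter; the dynamic-NTK variant now derives the length from `position_ids` via `seq_len = torch.max(position_ids) + 1`. Call sites simply drop the argument, e.g. (hypothetical names, not from this commit):

# Before (<= v4.38): seq_len was accepted but already unused
# cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
# After:
cos, sin = self.rotary_emb(value_states, position_ids)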
@@ -120,27 +120,10 @@ class OPTAttention(nn.Module):
     ):
         super().__init__()
         self.config = config
-
-        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
-            """
-            If a the deprecated argument `fn_arg_name` is passed, raise a deprecation
-            warning and return that value, otherwise take the equivalent config.config_arg_name
-            """
-            val = None
-            if fn_arg_name in kwargs:
-                logging.warning(
-                    "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
-                    "v4.39. Please set it in the config instead"
-                )
-                val = kwargs.pop(fn_arg_name)
-            else:
-                val = getattr(config, config_arg_name)
-            return val
-
-        self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
-        self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
-        self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
-        self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.dropout = config.attention_dropout
+        self.enable_bias = config.enable_bias
 
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
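`OPTAttention` now reads `embed_dim`, `num_heads`, `dropout`, and `enable_bias` exclusively from the config; the kwarg-with-warning path is gone. A construction sketch, under the assumption that the rest of the signature (`config` plus `is_decoder`) is unchanged and with made-up values:

from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTAttention

config = OPTConfig(hidden_size=768, num_attention_heads=12, attention_dropout=0.0, enable_bias=True)
# Before: OPTAttention(embed_dim=768, num_heads=12, dropout=0.0, bias=True, is_decoder=True)
attn = OPTAttention(config=config, is_decoder=True)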
@@ -408,10 +408,6 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
-def top_k_top_p_filtering(*args, **kwargs):
-    requires_backends(top_k_top_p_filtering, ["torch"])
-
-
 class PreTrainedModel(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -128,10 +128,6 @@ class TFTopPLogitsWarper(metaclass=DummyObject):
         requires_backends(self, ["tf"])
 
 
-def tf_top_k_top_p_filtering(*args, **kwargs):
-    requires_backends(tf_top_k_top_p_filtering, ["tf"])
-
-
 class KerasMetricCallback(metaclass=DummyObject):
     _backends = ["tf"]
 
@@ -41,7 +41,6 @@ if is_tf_available():
         TFBartForConditionalGeneration,
         TFLogitsProcessorList,
         TFMinLengthLogitsProcessor,
-        tf_top_k_top_p_filtering,
     )
     from transformers.modeling_tf_utils import keras
 
@@ -49,102 +48,6 @@ if is_tensorflow_text_available():
     import tensorflow_text as text
 
 
-@require_tf
-class UtilsFunctionsTest(unittest.TestCase):
-    # tests whether the top_k_top_p_filtering function behaves as expected
-    def test_top_k_top_p_filtering(self):
-        logits = tf.convert_to_tensor(
-            [
-                [
-                    8.2220991,  # 3rd highest value; idx. 0
-                    -0.5620044,
-                    5.23229752,
-                    4.0386393,
-                    -6.8798378,
-                    -0.54785802,
-                    -3.2012153,
-                    2.92777176,
-                    1.88171953,
-                    7.35341276,  # 5th highest value; idx. 9
-                    8.43207833,  # 2nd highest value; idx. 10
-                    -9.85711836,
-                    -5.96209236,
-                    -1.13039161,
-                    -7.1115294,
-                    -0.8369633,
-                    -5.3186408,
-                    7.06427407,
-                    0.81369344,
-                    -0.82023817,
-                    -5.9179796,
-                    0.58813443,
-                    -6.99778438,
-                    4.71551189,
-                    -0.18771637,
-                    7.44020759,  # 4th highest value; idx. 25
-                    9.38450987,  # 1st highest value; idx. 26
-                    2.12662941,
-                    -9.32562038,
-                    2.35652522,
-                ],  # cummulative prob of 5 highest values <= 0.6
-                [
-                    0.58425518,
-                    4.53139238,
-                    -5.57510464,
-                    -6.28030699,
-                    -7.19529503,
-                    -4.02122551,
-                    1.39337037,
-                    -6.06707057,
-                    1.59480517,
-                    -9.643119,
-                    0.03907799,
-                    0.67231762,
-                    -8.88206726,
-                    6.27115922,  # 4th highest value; idx. 13
-                    2.28520723,
-                    4.82767506,
-                    4.30421368,
-                    8.8275313,  # 2nd highest value; idx. 17
-                    5.44029958,  # 5th highest value; idx. 18
-                    -4.4735794,
-                    7.38579536,  # 3rd highest value; idx. 20
-                    -2.91051663,
-                    2.61946077,
-                    -2.5674762,
-                    -9.48959302,
-                    -4.02922645,
-                    -1.35416918,
-                    9.67702323,  # 1st highest value; idx. 27
-                    -5.89478553,
-                    1.85370467,
-                ],  # cummulative prob of 5 highest values <= 0.6
-            ],
-            dtype=tf.float32,
-        )
-
-        non_inf_expected_idx = tf.convert_to_tensor(
-            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
-            dtype=tf.int32,
-        )  # expected non filtered idx as noted above
-
-        non_inf_expected_output = tf.convert_to_tensor(
-            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
-            dtype=tf.float32,
-        )  # expected non filtered values as noted above
-
-        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
-
-        non_inf_output = output[output != -float("inf")]
-        non_inf_idx = tf.cast(
-            tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))),
-            dtype=tf.int32,
-        )
-
-        tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
-        tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
-
-
 @require_tf
 class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
     # setting framework_dependent_parameters needs to be gated, just like its contents' imports
@@ -52,7 +52,6 @@ if is_torch_available():
         GPT2Tokenizer,
         ImageGPTForCausalImageModeling,
         SpeechEncoderDecoderModel,
-        top_k_top_p_filtering,
     )
     from transformers.cache_utils import DynamicCache
     from transformers.generation import (
@@ -2345,133 +2344,6 @@ class GenerationTesterMixin:
 
 @require_torch
 class UtilsFunctionsTest(unittest.TestCase):
-    # tests whether the top_k_top_p function behaves as expected
-    def test_top_k_top_p_filtering(self):
-        logits = torch.tensor(
-            [
-                [
-                    8.2220991,  # 3rd highest value; idx. 0
-                    -0.5620044,
-                    5.23229752,
-                    4.0386393,
-                    -6.8798378,
-                    -0.54785802,
-                    -3.2012153,
-                    2.92777176,
-                    1.88171953,
-                    7.35341276,
-                    8.43207833,  # 2nd highest value; idx. 10
-                    -9.85711836,
-                    -5.96209236,
-                    -1.13039161,
-                    -7.1115294,
-                    -0.8369633,
-                    -5.3186408,
-                    7.06427407,
-                    0.81369344,
-                    -0.82023817,
-                    -5.9179796,
-                    0.58813443,
-                    -6.99778438,
-                    4.71551189,
-                    -0.18771637,
-                    7.44020759,  # 4th highest value; idx. 25
-                    9.38450987,  # 1st highest value; idx. 26
-                    2.12662941,
-                    -9.32562038,
-                    2.35652522,
-                ],  # cummulative prob of 4 highest values <= 0.6
-                [
-                    0.58425518,
-                    4.53139238,
-                    -5.57510464,
-                    -6.28030699,
-                    -7.19529503,
-                    -4.02122551,
-                    1.39337037,
-                    -6.06707057,
-                    1.59480517,
-                    -9.643119,
-                    0.03907799,
-                    0.67231762,
-                    -8.88206726,
-                    6.27115922,  # 4th highest value; idx. 13
-                    2.28520723,
-                    4.82767506,
-                    4.30421368,
-                    8.8275313,  # 2nd highest value; idx. 17
-                    5.44029958,
-                    -4.4735794,
-                    7.38579536,  # 3rd highest value; idx. 20
-                    -2.91051663,
-                    2.61946077,
-                    -2.5674762,
-                    -9.48959302,
-                    -4.02922645,
-                    -1.35416918,
-                    9.67702323,  # 1st highest value; idx. 27
-                    -5.89478553,
-                    1.85370467,
-                ],  # cummulative prob of 4 highest values <= 0.6
-            ],
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        non_inf_expected_idx = torch.tensor(
-            [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
-            dtype=torch.long,
-            device=torch_device,
-        )  # expected non filtered idx as noted above
-
-        non_inf_expected_output = torch.tensor(
-            [
-                8.2221,
-                8.4321,
-                7.4402,
-                9.3845,
-                6.2712,
-                8.8275,
-                7.3858,
-                9.6770,
-            ],  # expected non filtered values as noted above
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
-        non_inf_output = output[output != -float("inf")].to(device=torch_device)
-        non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
-
-        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
-        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
-
-    # tests whether the function uses filter_value instead of default -inf
-    def test_top_k_top_p_filtering_with_filter_value(self):
-        logits = torch.tensor(
-            [
-                [
-                    1,
-                    1,
-                    1,
-                    0.99,  # get filtered by top-p filtering
-                    0.98,  # get filtered by top-k filtering
-                ]
-            ],
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        expected_output = torch.tensor(
-            [[1, 1, 1, 0, 0]],
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0)
-
-        self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
-
     def test_speculative_sampling(self):
         # assume vocab size 10, input length 5 + 3 generated candidates
         candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens