Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)

v4.39 deprecations 🧼 (#29492)

Commit: ffe60fdcd6
Parent: 979fccc90f
@@ -336,12 +336,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens

- process
- finalize

## Utilities

[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

## Streamers

[[autodoc]] TextStreamer
@@ -335,12 +335,6 @@ generation_output[:2]

- process
- finalize

## Utilities

[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

## Streamers

[[autodoc]] TextStreamer
@@ -330,12 +330,6 @@ generation_output[:2]

- process
- finalize

## Utilities

[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

## Streamers

[[autodoc]] TextStreamer
@@ -1409,7 +1409,6 @@ else:

            "TypicalLogitsWarper",
            "UnbatchedClassifierFreeGuidanceLogitsProcessor",
            "WhisperTimeStampLogitsProcessor",
            "top_k_top_p_filtering",
        ]
    )
    _import_structure["generation_utils"] = []
@@ -3814,7 +3813,6 @@ else:

            "TFTemperatureLogitsWarper",
            "TFTopKLogitsWarper",
            "TFTopPLogitsWarper",
            "tf_top_k_top_p_filtering",
        ]
    )
    _import_structure["generation_tf_utils"] = []
@@ -6206,7 +6204,6 @@ if TYPE_CHECKING:

        TypicalLogitsWarper,
        UnbatchedClassifierFreeGuidanceLogitsProcessor,
        WhisperTimeStampLogitsProcessor,
        top_k_top_p_filtering,
    )
    from .modeling_utils import PreTrainedModel
    from .models.albert import (
@@ -8178,7 +8175,6 @@ if TYPE_CHECKING:

        TFTemperatureLogitsWarper,
        TFTopKLogitsWarper,
        TFTopPLogitsWarper,
        tf_top_k_top_p_filtering,
    )
    from .keras_callbacks import KerasMetricCallback, PushToHubCallback
    from .modeling_tf_utils import (
@@ -13,7 +13,6 @@

# limitations under the License.

import math
import warnings
from collections import OrderedDict

import torch
@@ -138,14 +137,6 @@ class AccurateGELUActivation(nn.Module):

        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))


class SiLUActivation(nn.SiLU):
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The SiLUActivation class has been deprecated and will be removed in v4.39. Please use nn.SiLU instead.",
        )
        super().__init__(*args, **kwargs)


class MishActivation(nn.Module):
    """
    See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
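Migration note: the deleted `SiLUActivation` class was only a thin wrapper around `torch.nn.SiLU`, so callers can switch to the PyTorch module directly. A minimal sketch (everything except `nn.SiLU` itself is illustrative):

    import torch
    from torch import nn

    # Previously: from transformers.activations import SiLUActivation; act = SiLUActivation()
    act = nn.SiLU()  # drop-in replacement for the removed wrapper

    x = torch.randn(2, 4)
    y = act(x)  # same result as torch.nn.functional.silu(x)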
@@ -88,7 +88,6 @@ else:

    ]
    _import_structure["utils"] = [
        "GenerationMixin",
        "top_k_top_p_filtering",
        "GreedySearchEncoderDecoderOutput",
        "GreedySearchDecoderOnlyOutput",
        "SampleEncoderDecoderOutput",
@@ -130,7 +129,6 @@ else:

    ]
    _import_structure["tf_utils"] = [
        "TFGenerationMixin",
        "tf_top_k_top_p_filtering",
        "TFGreedySearchDecoderOnlyOutput",
        "TFGreedySearchEncoderDecoderOutput",
        "TFSampleEncoderDecoderOutput",
@@ -241,7 +239,6 @@ if TYPE_CHECKING:

        GreedySearchEncoderDecoderOutput,
        SampleDecoderOnlyOutput,
        SampleEncoderDecoderOutput,
        top_k_top_p_filtering,
    )

    try:
@@ -279,7 +276,6 @@ if TYPE_CHECKING:

        TFGreedySearchEncoderDecoderOutput,
        TFSampleDecoderOnlyOutput,
        TFSampleEncoderDecoderOutput,
        tf_top_k_top_p_filtering,
    )

    try:
@@ -3088,68 +3088,6 @@ class TFGenerationMixin:

        return generated


def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """

    warnings.warn(
        "`tf_top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TFTopKLogitsWarper` and "
        "`TFTopPLogitsWarper` instead.",
        DeprecationWarning,
    )

    logits_shape = shape_list(logits)

    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
        logits = tf.where(indices_to_remove, filter_value, logits)
    if top_p < 1.0:
        sorted_indices = tf.argsort(logits, direction="DESCENDING")
        sorted_logits = tf.gather(
            logits, sorted_indices, axis=-1, batch_dims=1
        )  # expects logits to be of dim (batch_size, vocab_size)

        cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)

        # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p

        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove = tf.concat(
                [
                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
                    sorted_indices_to_remove[:, min_tokens_to_keep:],
                ],
                -1,
            )

        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove = tf.concat(
            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, :-1]],
            -1,
        )
        # scatter sorted tensors to original indexing
        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
        logits = tf.where(indices_to_remove, filter_value, logits)
    return logits


def scatter_values_on_batch_indices(values, batch_indices):
    shape = shape_list(batch_indices)
    # broadcast batch dim to shape
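Migration note: as the warning in the deleted function states, `TFTopKLogitsWarper` and `TFTopPLogitsWarper` replace `tf_top_k_top_p_filtering`. A hedged sketch of the equivalent call, assuming the warper signatures of this release (these warpers ignore the `input_ids` and `cur_len` arguments, so placeholder values are passed):

    import tensorflow as tf
    from transformers import TFTopKLogitsWarper, TFTopPLogitsWarper

    logits = tf.random.normal((2, 50))  # (batch_size, vocab_size)

    # Roughly equivalent to: tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
    scores = TFTopKLogitsWarper(top_k=10, filter_value=-float("inf"), min_tokens_to_keep=4)(None, logits, cur_len=0)
    scores = TFTopPLogitsWarper(top_p=0.6, filter_value=-float("inf"), min_tokens_to_keep=4)(None, scores, cur_len=0)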
@@ -4810,47 +4810,6 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at

    return outputs


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    warnings.warn(
        "`top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TopKLogitsWarper` and `TopPLogitsWarper` "
        "instead.",
        DeprecationWarning,
    )

    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits


def _ranking_fast(
    context_hidden: torch.FloatTensor,
    next_hidden: torch.FloatTensor,
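Migration note: the deleted PyTorch helper above already delegates to the warpers it tells users to adopt, so the replacement is mechanical. A minimal sketch using the same `(None, logits)` calling convention as the deleted code:

    import torch
    from transformers import TopKLogitsWarper, TopPLogitsWarper

    logits = torch.randn(2, 50)  # (batch_size, vocab_size)

    # Roughly equivalent to: top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
    logits = TopKLogitsWarper(top_k=10, filter_value=-float("inf"), min_tokens_to_keep=4)(None, logits)
    logits = TopPLogitsWarper(top_p=0.6, filter_value=-float("inf"), min_tokens_to_keep=4)(None, logits)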
@@ -129,10 +129,7 @@ class LlamaRotaryEmbedding(nn.Module):

        return self._cos_cached

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        if seq_len is not None:
            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")

    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
@@ -151,17 +148,17 @@ class LlamaRotaryEmbedding(nn.Module):

class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids, seq_len=None):
    def forward(self, x, position_ids):
        # difference to the original RoPE: a scaling factor is applied to the position ids
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids, seq_len)
        cos, sin = super().forward(x, position_ids)
        return cos, sin


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids, seq_len=None):
    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
@@ -173,7 +170,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):

            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids, seq_len)
        cos, sin = super().forward(x, position_ids)
        return cos, sin
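Migration note: after this change the Llama rotary-embedding modules derive the sequence length from `position_ids`, so the `seq_len` argument is simply dropped at call sites. A small sketch, assuming this release's constructor (head dimension as the first argument) and illustrative tensor shapes:

    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

    rotary_emb = LlamaRotaryEmbedding(64)  # head_dim = 64
    x = torch.randn(1, 8, 16, 64)  # [bs, num_attention_heads, seq_len, head_size]
    position_ids = torch.arange(16).unsqueeze(0)

    # Previously: cos, sin = rotary_emb(x, position_ids, seq_len=16)
    cos, sin = rotary_emb(x, position_ids)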
@@ -120,27 +120,10 @@ class OPTAttention(nn.Module):

    ):
        super().__init__()
        self.config = config

        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
            """
            If the deprecated argument `fn_arg_name` is passed, raise a deprecation
            warning and return that value, otherwise take the equivalent config.config_arg_name
            """
            val = None
            if fn_arg_name in kwargs:
                logging.warning(
                    "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
                    "v4.39. Please set it in the config instead"
                )
                val = kwargs.pop(fn_arg_name)
            else:
                val = getattr(config, config_arg_name)
            return val

        self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
        self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
        self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
        self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        self.enable_bias = config.enable_bias

        self.head_dim = self.embed_dim // self.num_heads
        self.is_causal = True
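Migration note: with the kwarg fallback removed, the attention hyperparameters are read only from the model config, as the old deprecation message already advised. A hedged sketch of setting them on `OPTConfig` instead of passing them to the layer (values are illustrative):

    from transformers import OPTConfig, OPTForCausalLM

    config = OPTConfig(
        hidden_size=768,
        num_attention_heads=12,
        attention_dropout=0.1,
        enable_bias=True,
    )
    # OPTAttention now takes embed_dim, num_heads, dropout and bias from this config
    model = OPTForCausalLM(config)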
@@ -408,10 +408,6 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):

        requires_backends(self, ["torch"])


def top_k_top_p_filtering(*args, **kwargs):
    requires_backends(top_k_top_p_filtering, ["torch"])


class PreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]
@@ -128,10 +128,6 @@ class TFTopPLogitsWarper(metaclass=DummyObject):

        requires_backends(self, ["tf"])


def tf_top_k_top_p_filtering(*args, **kwargs):
    requires_backends(tf_top_k_top_p_filtering, ["tf"])


class KerasMetricCallback(metaclass=DummyObject):
    _backends = ["tf"]
@@ -41,7 +41,6 @@ if is_tf_available():

        TFBartForConditionalGeneration,
        TFLogitsProcessorList,
        TFMinLengthLogitsProcessor,
        tf_top_k_top_p_filtering,
    )
    from transformers.modeling_tf_utils import keras
@@ -49,102 +48,6 @@ if is_tensorflow_text_available():

    import tensorflow_text as text


@require_tf
class UtilsFunctionsTest(unittest.TestCase):
    # tests whether the top_k_top_p_filtering function behaves as expected
    def test_top_k_top_p_filtering(self):
        logits = tf.convert_to_tensor(
            [
                [
                    8.2220991,  # 3rd highest value; idx. 0
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,  # 5th highest value; idx. 9
                    8.43207833,  # 2nd highest value; idx. 10
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,  # 4th highest value; idx. 25
                    9.38450987,  # 1st highest value; idx. 26
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],  # cumulative prob of 5 highest values <= 0.6
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,  # 4th highest value; idx. 13
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,  # 2nd highest value; idx. 17
                    5.44029958,  # 5th highest value; idx. 18
                    -4.4735794,
                    7.38579536,  # 3rd highest value; idx. 20
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,  # 1st highest value; idx. 27
                    -5.89478553,
                    1.85370467,
                ],  # cumulative prob of 5 highest values <= 0.6
            ],
            dtype=tf.float32,
        )

        non_inf_expected_idx = tf.convert_to_tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=tf.int32,
        )  # expected non filtered idx as noted above

        non_inf_expected_output = tf.convert_to_tensor(
            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
            dtype=tf.float32,
        )  # expected non filtered values as noted above

        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)

        non_inf_output = output[output != -float("inf")]
        non_inf_idx = tf.cast(
            tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))),
            dtype=tf.int32,
        )

        tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
        tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)


@require_tf
class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
    # setting framework_dependent_parameters needs to be gated, just like its contents' imports
@@ -52,7 +52,6 @@ if is_torch_available():

        GPT2Tokenizer,
        ImageGPTForCausalImageModeling,
        SpeechEncoderDecoderModel,
        top_k_top_p_filtering,
    )
    from transformers.cache_utils import DynamicCache
    from transformers.generation import (
@@ -2345,133 +2344,6 @@ class GenerationTesterMixin:


@require_torch
class UtilsFunctionsTest(unittest.TestCase):
    # tests whether the top_k_top_p function behaves as expected
    def test_top_k_top_p_filtering(self):
        logits = torch.tensor(
            [
                [
                    8.2220991,  # 3rd highest value; idx. 0
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,
                    8.43207833,  # 2nd highest value; idx. 10
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,  # 4th highest value; idx. 25
                    9.38450987,  # 1st highest value; idx. 26
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],  # cumulative prob of 4 highest values <= 0.6
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,  # 4th highest value; idx. 13
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,  # 2nd highest value; idx. 17
                    5.44029958,
                    -4.4735794,
                    7.38579536,  # 3rd highest value; idx. 20
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,  # 1st highest value; idx. 27
                    -5.89478553,
                    1.85370467,
                ],  # cumulative prob of 4 highest values <= 0.6
            ],
            dtype=torch.float,
            device=torch_device,
        )

        non_inf_expected_idx = torch.tensor(
            [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
            dtype=torch.long,
            device=torch_device,
        )  # expected non filtered idx as noted above

        non_inf_expected_output = torch.tensor(
            [
                8.2221,
                8.4321,
                7.4402,
                9.3845,
                6.2712,
                8.8275,
                7.3858,
                9.6770,
            ],  # expected non filtered values as noted above
            dtype=torch.float,
            device=torch_device,
        )

        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
        non_inf_output = output[output != -float("inf")].to(device=torch_device)
        non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)

        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))

    # tests whether the function uses filter_value instead of default -inf
    def test_top_k_top_p_filtering_with_filter_value(self):
        logits = torch.tensor(
            [
                [
                    1,
                    1,
                    1,
                    0.99,  # get filtered by top-p filtering
                    0.98,  # get filtered by top-k filtering
                ]
            ],
            dtype=torch.float,
            device=torch_device,
        )

        expected_output = torch.tensor(
            [[1, 1, 1, 0, 0]],
            dtype=torch.float,
            device=torch_device,
        )

        output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0)

        self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))

    def test_speculative_sampling(self):
        # assume vocab size 10, input length 5 + 3 generated candidates
        candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens