Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Reducing memory usage: removing useless logits computation in generate() (#31292)
* Add .float() in all generation methods logit outputs
* Switch float-casting of logits to training only for main models
* Add `num_logits_to_keep` in Llama and add it by default in generate
* Apply style
* Add num_logits_to_keep as arg in prepare_input_for_generation
* Add support for Mistral
* Revert models except llama and mistral
* Fix default None value in _supports_num_logits_to_keep()
* Fix dimension of dummy input
* Add exception for prophetnet in _supports_num_logits_to_keep()
* Update _supports_num_logits_to_keep() to use inspect.signature()
* Add deprecation cycle + remove modification with pretraining_tp
* Apply style
* Add most used models
* Apply style
* Make `num_logits_to_keep` an int in all cases to remove if-else clause
* Add compile check for the warning
* Fix torch versions
* style
* Add gemma2
* Update warning version
* Add comment about .float operations in generation utils
* Add tests in GenerationTesterMixin and ModelTesterMixin
* Fix batch size for assisted decoding in tests
* fix small issues in test
* refacor test
* fix slicing removing dim issue
* Add nemotron support (should fix check-copy issue in CIs)
* Trigger new CIs
* Trigger new CIs
* Bump version
* Bump version in TODO
* Trigger CIs
* remove blank space
* Trigger CIs
This commit is contained in:
parent d806fa3e92
commit 22e6f14525
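The point of the change, as a minimal sketch (the sizes below are illustrative only and are not taken from this commit): when only the next token is needed, projecting just the last hidden state through the LM head shrinks the logit tensor from (batch, seq_len, vocab) to (batch, 1, vocab).

import torch
import torch.nn as nn

# Illustrative sizes only (roughly Llama-like); not taken from the diff.
batch, seq_len, hidden, vocab = 1, 8192, 1024, 32000
lm_head = nn.Linear(hidden, vocab, bias=False)
hidden_states = torch.randn(batch, seq_len, hidden)

full_bytes = batch * seq_len * vocab * 4  # fp32 logits for every position
kept_bytes = batch * 1 * vocab * 4        # logits for the last position only
print(f"full logits: {full_bytes / 2**30:.2f} GiB, kept: {kept_bytes / 2**10:.0f} KiB")

# What the patched forward() effectively does when num_logits_to_keep=1:
num_logits_to_keep = 1
logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])  # shape (1, 1, 32000)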
@@ -119,6 +119,10 @@ class AssistedCandidateGenerator(CandidateGenerator):
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)

# Remove potential default "num_logits_to_keep" key
if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep():
del assistant_kwargs["num_logits_to_keep"]

if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
elif assistant_model.config.is_encoder_decoder:
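This hunk drops the key for assistant models that do not accept it; on the target-model side, assisted decoding needs more than one kept logit, since it must score every candidate token plus one. A hedged sketch of that override (helper name and structure are illustrative, not the verbatim library code):

def override_num_logits_to_keep(model, model_inputs, candidate_length):
    # Illustrative helper: the target model must score candidate_length + 1
    # positions, so the default of 1 set by generate() cannot be used here.
    if model._supports_num_logits_to_keep():
        model_inputs["num_logits_to_keep"] = candidate_length + 1
    return model_inputs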
@@ -1601,6 +1601,13 @@ class GenerationMixin:
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)

def _supports_num_logits_to_keep(self) -> bool:
"""
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
to save memory. Checking it in this way allows to avoid using a new model attribute.
"""
return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())

def _prepare_special_tokens(
self,
generation_config: GenerationConfig,

@@ -1876,6 +1883,13 @@ class GenerationMixin:
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)

# If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
# dynamically overrides this value as it can need more than the last token logits
if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs:
model_kwargs["num_logits_to_keep"] = 1

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

# 7. Prepare the cache.

@@ -2412,8 +2426,9 @@ class GenerationMixin:
output_hidden_states=True,
)

final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone()
final_logits = outputs.logits[:, -1, :]
# .float() is needed to retain precision for later logits manipulations
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float()
final_logits = outputs.logits[:, -1, :].float()
candidate_premature_logits = {}
for candidate_premature_layer in candidate_premature_layers:
candidate_premature_logits[candidate_premature_layer] = lm_head(

@@ -2590,7 +2605,8 @@ class GenerationMixin:
# next logit for contrastive search to select top-k candidate tokens
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
# (the clone itself is always small)
logit_for_next_step = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
logit_for_next_step = outputs.logits[:, -1, :].clone().float()

model_kwargs = self._update_model_kwargs_for_generation(
outputs,

@@ -2720,7 +2736,8 @@ class GenerationMixin:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states

logits = outputs.logits[:, -1, :]
# .float() is needed to retain precision for later logits manipulations
logits = outputs.logits[:, -1, :].float()
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the

@@ -2966,7 +2983,8 @@ class GenerationMixin:

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()

# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)

@@ -3210,7 +3228,8 @@ class GenerationMixin:

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)

@@ -3484,7 +3503,8 @@ class GenerationMixin:

# select outputs of beams of current group only
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
next_token_logits = outputs.logits[batch_group_indices, -1, :]
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[batch_group_indices, -1, :].float()

next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1

@@ -3739,7 +3759,8 @@ class GenerationMixin:

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)

@@ -3998,7 +4019,8 @@ class GenerationMixin:
outputs = self(**model_inputs)

# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
# .float() is needed to retain precision for later logits manipulations
new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present
next_token_logits = new_logits.clone()
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
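A small standalone illustration of the two pieces added above, the signature check and the default that generate() injects (the classes here are hypothetical, only the check mirrors the library code):

import inspect

class WithKeep:
    def forward(self, input_ids, num_logits_to_keep: int = 0): ...

class WithoutKeep:
    def forward(self, input_ids): ...

def supports_num_logits_to_keep(model) -> bool:
    # Same idea as GenerationMixin._supports_num_logits_to_keep(): inspect the
    # forward() signature instead of adding a new model attribute.
    return "num_logits_to_keep" in inspect.signature(model.forward).parameters

print(supports_num_logits_to_keep(WithKeep()))     # True  -> generate() passes num_logits_to_keep=1
print(supports_num_logits_to_keep(WithoutKeep()))  # False -> the kwarg is left out or removed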
@@ -44,6 +44,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1024,6 +1025,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1032,6 +1034,11 @@ class CohereForCausalLM(CoherePreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1071,12 +1078,19 @@ class CohereForCausalLM(CoherePreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
logits = logits * self.logit_scale
logits = logits.float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1109,6 +1123,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1166,6 +1181,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -1275,6 +1275,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""Forward function for causal language modeling.

@@ -1284,6 +1285,11 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1328,7 +1334,8 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# No upscaling to float was ever done for Dbrx
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

loss = None
if labels is not None:

@@ -1380,6 +1387,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1437,6 +1445,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -43,6 +43,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1038,6 +1039,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1046,6 +1048,11 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1085,10 +1092,18 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1121,6 +1136,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1177,6 +1193,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -41,6 +41,7 @@ from ...utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -987,6 +988,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -995,6 +997,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1038,15 +1045,23 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping

# TODO: remove the float() operation in v4.46
logits = logits.float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1079,6 +1094,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1139,6 +1155,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
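The Gemma2 hunk above keeps its final-logit soft-capping on the sliced logits. As a quick numeric sketch of what that capping does (the value 30.0 is illustrative; the model reads it from config.final_logit_softcapping):

import torch

final_logit_softcapping = 30.0  # illustrative value
logits = torch.tensor([[-200.0, -10.0, 0.0, 10.0, 200.0]])

capped = torch.tanh(logits / final_logit_softcapping) * final_logit_softcapping
print(capped)  # values are squashed smoothly into the open interval (-30, 30)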
@@ -33,6 +33,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1520,6 +1521,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
r"""
Args:

@@ -1528,6 +1530,12 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`).
Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1591,11 +1599,18 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:

@@ -1623,7 +1638,13 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
num_logits_to_keep=0,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values

@@ -1682,6 +1703,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
"pixel_values": pixel_values,
"pixel_attention_mask": pixel_attention_mask,
"image_hidden_states": image_hidden_states,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -50,6 +50,7 @@ from ...utils.import_utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_mamba_ssm_available,
is_torchdynamo_compiling,
)
from .configuration_jamba import JambaConfig

@@ -1497,10 +1498,17 @@ class JambaForCausalLM(JambaPreTrainedModel):
logits = self.lm_head(hidden_states)
else:
logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
if labels is None and not is_torchdynamo_compiling:
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# TODO: remove the float() operations in v4.46
logits = logits.float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@@ -37,6 +37,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1253,6 +1254,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:

@@ -1261,6 +1263,11 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:
"""

@@ -1285,11 +1292,18 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1339,6 +1353,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
output_router_logits=False,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1371,6 +1386,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
"use_cache": use_cache,
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -44,6 +44,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1146,6 +1147,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1154,6 +1156,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1198,11 +1205,18 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1235,6 +1249,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1292,6 +1307,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
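The user-facing effect for models that gained the new kwarg, sketched here with Llama (the checkpoint name is only an example; any LlamaForCausalLM works):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Meta-Llama-3-8B"  # example checkpoint
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

inputs = tok("A very long prompt " * 500, return_tensors="pt")

# generate() now injects num_logits_to_keep=1 for supported models, so the first
# forward pass builds a (batch, 1, vocab) logit tensor instead of (batch, prompt_len, vocab).
out = model.generate(**inputs, max_new_tokens=16)

# The kwarg can also be passed to forward() directly; 0 keeps the old behaviour
# of computing logits for every position.
last_only = model(**inputs, num_logits_to_keep=1).logits  # shape (1, 1, vocab_size)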
@@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -996,6 +997,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1004,6 +1006,11 @@ class MistralForCausalLM(MistralPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1044,11 +1051,18 @@ class MistralForCausalLM(MistralPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1081,6 +1095,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1115,6 +1130,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -43,6 +43,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1233,6 +1234,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:

@@ -1241,6 +1243,11 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1286,11 +1293,18 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1339,6 +1353,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
output_router_logits=False,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1371,6 +1386,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
"use_cache": use_cache,
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1029,6 +1030,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1037,6 +1039,11 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1076,11 +1083,19 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
# TODO: remove the float() operation in v4.46
logits = logits.float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1113,6 +1128,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1170,6 +1186,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1068,6 +1069,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1076,6 +1078,11 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1116,11 +1123,18 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1153,6 +1167,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1210,6 +1225,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -885,6 +885,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -893,6 +894,11 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -933,7 +939,8 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# No upscaling to float was ever done for Persimmon
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

loss = None
if labels is not None:

@@ -970,6 +977,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1027,6 +1035,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -41,6 +41,7 @@ from ...utils import (
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1169,6 +1170,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1177,6 +1179,11 @@ class PhiForCausalLM(PhiPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1217,11 +1224,18 @@ class PhiForCausalLM(PhiPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1255,6 +1269,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1312,6 +1327,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -40,6 +40,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1210,6 +1211,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1218,6 +1220,11 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1257,11 +1264,18 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1295,6 +1309,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1352,6 +1367,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1067,6 +1068,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:

@@ -1075,6 +1077,11 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1115,11 +1122,18 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1153,6 +1167,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1210,6 +1225,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@@ -43,6 +43,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)

@@ -1244,6 +1245,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:

@@ -1252,6 +1254,11 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:

@@ -1296,11 +1303,18 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

@@ -1349,6 +1363,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens

@@ -1406,6 +1421,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
@ -1164,6 +1164,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1172,6 +1173,11 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:
@ -1211,7 +1217,8 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# No upscaling to float was ever done for StableLm
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

loss = None
if labels is not None:
@ -1248,6 +1255,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1305,6 +1313,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs

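To make the docstring's memory claim concrete: with, say, a 150k-entry vocabulary, an 8k-token prompt, and FP32 logits, the full logits tensor is roughly 8192 * 150000 * 4 bytes, about 4.9 GB per sequence, while keeping only the last position needs about 150000 * 4 bytes, around 0.6 MB. A back-of-the-envelope check with purely illustrative numbers (not tied to StableLm specifically):

    # Rough estimate of the logits memory saved by num_logits_to_keep=1
    vocab_size = 150_000   # hypothetical vocabulary size
    seq_len = 8_192        # hypothetical prompt length
    bytes_per_value = 4    # FP32 logits

    full_logits = seq_len * vocab_size * bytes_per_value  # num_logits_to_keep=0
    last_only = 1 * vocab_size * bytes_per_value           # num_logits_to_keep=1
    print(f"{full_logits / 1e9:.1f} GB vs {last_only / 1e6:.1f} MB per sequence")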
@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1043,6 +1044,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1051,6 +1053,11 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:
@ -1091,11 +1098,18 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1129,6 +1143,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1186,6 +1201,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs

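Note that the training path is untouched: the shifted cross-entropy above needs a logit for every label position, which is why `num_logits_to_keep` defaults to 0 in forward() and is only set to 1 automatically during generation (as the tests further down check), where no labels are involved. A toy illustration of the shift that the loss code performs, with hypothetical shapes standing in for the real tensors:

    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab_size = 2, 8, 10
    logits = torch.randn(batch, seq_len, vocab_size)          # the loss needs the full sequence of logits
    labels = torch.randint(0, vocab_size, (batch, seq_len))

    # Tokens < n predict token n, so drop the last logit and the first label
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = F.cross_entropy(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
    print(loss.item())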
@ -1828,6 +1828,62 @@ class GenerationTesterMixin:
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())

def test_generate_methods_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")

config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True

model = model_class(config).to(torch_device).eval()
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
}

# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

def test_assisted_decoding_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")

config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
config.use_cache = True
config.is_decoder = True

model = model_class(config).to(torch_device).eval()
assistant_model = model
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
}

# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences

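From a user's point of view the change is transparent, which is what the two tests above assert: the default and the explicit num_logits_to_keep=0 calls must produce the same tokens. A usage sketch (the checkpoint name is a placeholder for any causal LM that supports the new argument):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "your/causal-lm-checkpoint"  # placeholder, not a real checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    inputs = tokenizer("Hello", return_tensors="pt")

    # Default: generate() injects num_logits_to_keep=1 for supported models, so the
    # prefill forward only materializes logits for the last prompt token.
    out_default = model.generate(**inputs, max_new_tokens=10, do_sample=False)

    # Passing 0 explicitly restores the old "compute all logits" behavior; the
    # generated tokens should be identical, only peak memory differs.
    out_full = model.generate(**inputs, max_new_tokens=10, do_sample=False, num_logits_to_keep=0)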
@ -4824,6 +4824,27 @@ class ModelTesterMixin:
self.assertTrue(record_time < 0.15 * graph_warmup_time)
self.assertTrue(opt_time < record_time)

def test_forward_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")

config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape
vocab_size = config.vocab_size
model = model_class(config).to(device=torch_device).eval()

# num_logits_to_keep=0 is a special case meaning "keep all logits"
all_logits = model(**inputs, num_logits_to_keep=0).logits
last_token_logits = model(**inputs, num_logits_to_keep=1).logits

# Assert all shapes are correct
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size))

# Assert the last tokens are actually the same
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits))


global_rng = random.Random()

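As a quick way to see the forward() contract that test_forward_with_num_logits_to_keep encodes, one can build a tiny randomly initialized model and compare the two settings directly. A sketch using a small Llama configuration (Llama gained the same argument in this change; the sizes below are arbitrary toy values):

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig(
        vocab_size=64,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaForCausalLM(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 10))

    with torch.no_grad():
        all_logits = model(input_ids, num_logits_to_keep=0).logits    # (2, 10, 64)
        last_logits = model(input_ids, num_logits_to_keep=1).logits   # (2, 1, 64)

    print(all_logits.shape, last_logits.shape)
    print(torch.allclose(all_logits[:, -1:, :], last_logits, atol=1e-5))  # expected: True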