Reducing memory usage: removing useless logits computation in generate() (#31292)

* Add .float() in all generation methods logit outputs

* Switch float-casting of logits to training only for main models

* Add `num_logits_to_keep` in Llama and add it by default in generate

* Apply style

* Add num_logits_to_keep as arg in prepare_input_for_generation

* Add support for Mistral

* Revert models except llama and mistral

* Fix default None value in _supports_num_logits_to_keep()

* Fix dimension of dummy input

* Add exception for prophetnet in _supports_num_logits_to_keep()

* Update _supports_num_logits_to_keep() to use inspect.signature()

* Add deprecation cycle + remove modification with pretraining_tp

* Apply style

* Add most used models

* Apply style

* Make `num_logits_to_keep` an int in all cases to remove if-else clause

* Add compile check for the warning

* Fix torch versions

* style

* Add gemma2

* Update warning version

* Add comment about .float operations in generation utils

* Add tests in GenerationTesterMixin and ModelTesterMixin

* Fix batch size for assisted decoding in tests

* fix small issues in test

* refactor test

* fix issue where slicing removed a dimension

* Add nemotron support (should fix check-copy issue in CIs)

* Trigger new CIs

* Trigger new CIs

* Bump version

* Bump version in TODO

* Trigger CIs

* remove blank space

* Trigger CIs
Commit 22e6f14525 (parent d806fa3e92) by Cyril Vallez, 2024-08-23 12:08:34 +02:00, committed by GitHub
23 changed files with 428 additions and 41 deletions
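
To make the intent of the change concrete, here is a short usage sketch (not part of the diff). The checkpoint name is only an example; any causal LM whose forward() gained the new `num_logits_to_keep` argument behaves the same way.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)

inputs = tokenizer("A very long prompt " * 200, return_tensors="pt")

# Before this commit, the prefill forward pass materialised fp32 logits of shape
# (batch, seq_len, vocab_size) even though generate() only needs the last position.
# With the new argument, only the last row is computed:
out = model(**inputs, num_logits_to_keep=1)
print(out.logits.shape)  # (1, 1, vocab_size)

# num_logits_to_keep=0 keeps the old behaviour: logits for every position.
out_full = model(**inputs, num_logits_to_keep=0)
print(out_full.logits.shape)  # (1, seq_len, vocab_size)

# generate() now injects num_logits_to_keep=1 automatically when the model
# supports it, so callers get the memory saving without changing their code.
generated = model.generate(**inputs, max_new_tokens=20)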


@ -119,6 +119,10 @@ class AssistedCandidateGenerator(CandidateGenerator):
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)
# Remove potential default "num_logits_to_keep" key
if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep():
del assistant_kwargs["num_logits_to_keep"]
if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
elif assistant_model.config.is_encoder_decoder:

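The hunk above strips the default `num_logits_to_keep` key from the assistant kwargs when the assistant model does not accept it. The reason assisted decoding cannot simply keep the last-token default is sketched below with toy shapes (not library code): the target model has to score every drafted token plus one extra position in a single verification pass.

import torch

# Toy sizes only, not library code.
batch, seq_len, vocab = 1, 32, 128
candidate_length = 5  # tokens proposed by the assistant model

logits = torch.randn(batch, seq_len, vocab)

# Verification slice used by assisted decoding (see the `new_logits` hunk in the
# generation utils below): scores are needed for all drafted tokens plus the
# "bonus" next token, so a single last-token row is not enough.
new_logits = logits[:, -candidate_length - 1 :]
print(new_logits.shape)  # torch.Size([1, 6, 128])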

@ -1601,6 +1601,13 @@ class GenerationMixin:
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
def _supports_num_logits_to_keep(self) -> bool:
"""
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
to save memory. Checking it this way avoids adding a new model attribute.
"""
return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
def _prepare_special_tokens(
self,
generation_config: GenerationConfig,
@ -1876,6 +1883,13 @@ class GenerationMixin:
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
# If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
# dynamically overrides this value as it can need more than the last token logits
if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs:
model_kwargs["num_logits_to_keep"] = 1
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. Prepare the cache.
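
The capability check above is pure signature introspection, so no new model attribute or config flag is needed. A minimal sketch of the pattern, using toy classes and an illustrative helper name:

import inspect

class OldModel:
    def forward(self, input_ids, attention_mask=None):
        ...

class NewModel:
    def forward(self, input_ids, attention_mask=None, num_logits_to_keep=0):
        ...

def supports_num_logits_to_keep(model) -> bool:
    # The feature is detected purely from the forward() signature.
    return "num_logits_to_keep" in inspect.signature(model.forward).parameters

model_kwargs = {}
model = NewModel()
if supports_num_logits_to_keep(model) and "num_logits_to_keep" not in model_kwargs:
    model_kwargs["num_logits_to_keep"] = 1  # only the last position is computed at prefill

print(supports_num_logits_to_keep(OldModel()))  # False
print(model_kwargs)                             # {'num_logits_to_keep': 1}
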
@ -2412,8 +2426,9 @@ class GenerationMixin:
output_hidden_states=True,
)
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone()
final_logits = outputs.logits[:, -1, :]
# .float() is needed to retain precision for later logits manipulations
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float()
final_logits = outputs.logits[:, -1, :].float()
candidate_premature_logits = {}
for candidate_premature_layer in candidate_premature_layers:
candidate_premature_logits[candidate_premature_layer] = lm_head(
@ -2590,7 +2605,8 @@ class GenerationMixin:
# next logit for contrastive search to select top-k candidate tokens
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
# (the clone itself is always small)
logit_for_next_step = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
logit_for_next_step = outputs.logits[:, -1, :].clone().float()
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
@ -2720,7 +2736,8 @@ class GenerationMixin:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
logits = outputs.logits[:, -1, :]
# .float() is needed to retain precision for later logits manipulations
logits = outputs.logits[:, -1, :].float()
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@ -2966,7 +2983,8 @@ class GenerationMixin:
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
@ -3210,7 +3228,8 @@ class GenerationMixin:
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
@ -3484,7 +3503,8 @@ class GenerationMixin:
# select outputs of beams of current group only
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
next_token_logits = outputs.logits[batch_group_indices, -1, :]
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[batch_group_indices, -1, :].float()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
@ -3739,7 +3759,8 @@ class GenerationMixin:
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
@ -3998,7 +4019,8 @@ class GenerationMixin:
outputs = self(**model_inputs)
# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
# .float() is needed to retain precision for later logits manipulations
new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present
next_token_logits = new_logits.clone()
if len(logits_processor) > 0:
for i in range(candidate_length + 1):

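The `.float()` upcasts added throughout this file exist because later logits manipulations (log-softmax, logits processors, score accumulation) lose noticeable precision in half precision. A rough illustration with toy values; the exact error magnitudes depend on hardware and dtype:

import torch

torch.manual_seed(0)
logits_bf16 = torch.randn(1, 32000, dtype=torch.bfloat16) * 10  # peaked, LM-like logits

scores_half = torch.log_softmax(logits_bf16, dim=-1)          # stays in bf16
scores_fp32 = torch.log_softmax(logits_bf16.float(), dim=-1)  # upcast first, as above

# The implied probabilities should sum to 1; the half-precision path drifts much further.
print((scores_half.exp().sum() - 1).abs().item())
print((scores_fp32.exp().sum() - 1).abs().item())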

@ -44,6 +44,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1024,6 +1025,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1032,6 +1034,11 @@ class CohereForCausalLM(CoherePreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1071,12 +1078,19 @@ class CohereForCausalLM(CoherePreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
logits = logits * self.logit_scale
logits = logits.float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1109,6 +1123,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1166,6 +1181,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
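
The model-side change is the slice `hidden_states[:, -num_logits_to_keep:, :]` applied before the LM head, as in the Cohere hunk above. A toy-tensor sketch of the slicing convention, including why `0` is the "keep everything" special case (`-0 == 0`, so the slice starts at position 0):

import torch

hidden_states = torch.randn(2, 7, 16)  # (batch, seq_len, hidden)

for num_logits_to_keep in (0, 1, 3):
    kept = hidden_states[:, -num_logits_to_keep:, :]
    print(num_logits_to_keep, kept.shape)
# 0 torch.Size([2, 7, 16])  <- all positions (old behaviour)
# 1 torch.Size([2, 1, 16])  <- last position only, what generate() needs at prefill
# 3 torch.Size([2, 3, 16])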


@ -1275,6 +1275,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""Forward function for causal language modeling.
@ -1284,6 +1285,11 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1328,7 +1334,8 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# No upscaling to float was ever done for Dbrx
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
@ -1380,6 +1387,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1437,6 +1445,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -43,6 +43,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1038,6 +1039,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1046,6 +1048,11 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1085,10 +1092,18 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1121,6 +1136,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1177,6 +1193,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -41,6 +41,7 @@ from ...utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -987,6 +988,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -995,6 +997,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1038,15 +1045,23 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
# TODO: remove the float() operation in v4.46
logits = logits.float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1079,6 +1094,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1139,6 +1155,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
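
Gemma2 additionally applies its final logit softcapping to the sliced logits before the (temporary) float upcast. A minimal sketch of that softcapping with toy values; the helper name and the cap value are illustrative:

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squashes logits into (-cap, cap) instead of hard-clipping them.
    return cap * torch.tanh(logits / cap)

logits = torch.tensor([-120.0, -5.0, 0.0, 5.0, 120.0])
print(soft_cap(logits, cap=30.0))  # bounded by +/- 30; small values are nearly unchanged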


@ -33,6 +33,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1520,6 +1521,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
r"""
Args:
@ -1528,6 +1530,12 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`).
Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1591,11 +1599,18 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
@ -1623,7 +1638,13 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
num_logits_to_keep=0,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values
@ -1682,6 +1703,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
"pixel_values": pixel_values,
"pixel_attention_mask": pixel_attention_mask,
"image_hidden_states": image_hidden_states,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -50,6 +50,7 @@ from ...utils.import_utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_mamba_ssm_available,
is_torchdynamo_compiling,
)
from .configuration_jamba import JambaConfig
@ -1497,10 +1498,17 @@ class JambaForCausalLM(JambaPreTrainedModel):
logits = self.lm_head(hidden_states)
else:
logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# TODO: remove the float() operations in v4.46
logits = logits.float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()


@ -37,6 +37,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1253,6 +1254,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1261,6 +1263,11 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
"""
@ -1285,11 +1292,18 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1339,6 +1353,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
output_router_logits=False,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1371,6 +1386,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
"use_cache": use_cache,
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -44,6 +44,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1146,6 +1147,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1154,6 +1156,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1198,11 +1205,18 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1235,6 +1249,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1292,6 +1307,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -996,6 +997,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1004,6 +1006,11 @@ class MistralForCausalLM(MistralPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1044,11 +1051,18 @@ class MistralForCausalLM(MistralPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1081,6 +1095,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1115,6 +1130,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -43,6 +43,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1233,6 +1234,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1241,6 +1243,11 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1286,11 +1293,18 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1339,6 +1353,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
output_router_logits=False,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1371,6 +1386,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
"use_cache": use_cache,
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1029,6 +1030,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1037,6 +1039,11 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1076,11 +1083,19 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
# TODO: remove the float() operation in v4.46
logits = logits.float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1113,6 +1128,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1170,6 +1186,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1068,6 +1069,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1076,6 +1078,11 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1116,11 +1123,18 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1153,6 +1167,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1210,6 +1225,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -885,6 +885,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -893,6 +894,11 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -933,7 +939,8 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# No upscaling to float was ever done for Persimmon
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
@ -970,6 +977,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1027,6 +1035,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -41,6 +41,7 @@ from ...utils import (
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1169,6 +1170,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1177,6 +1179,11 @@ class PhiForCausalLM(PhiPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1217,11 +1224,18 @@ class PhiForCausalLM(PhiPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1255,6 +1269,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1312,6 +1327,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -40,6 +40,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1210,6 +1211,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1218,6 +1220,11 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1257,11 +1264,18 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1295,6 +1309,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1352,6 +1367,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1067,6 +1068,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1075,6 +1077,11 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1115,11 +1122,18 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1153,6 +1167,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1210,6 +1225,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -43,6 +43,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1244,6 +1245,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1252,6 +1254,11 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1296,11 +1303,18 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1349,6 +1363,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1406,6 +1421,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -1164,6 +1164,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1172,6 +1173,11 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1211,7 +1217,8 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# No upscaling to float was ever done for StableLm
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
@ -1248,6 +1255,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1305,6 +1313,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -1043,6 +1044,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1051,6 +1053,11 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
@ -1091,11 +1098,18 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -1129,6 +1143,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1186,6 +1201,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs


@ -1828,6 +1828,62 @@ class GenerationTesterMixin:
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
def test_generate_methods_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
}
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
def test_assisted_decoding_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
assistant_model = model
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
}
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences


@ -4824,6 +4824,27 @@ class ModelTesterMixin:
self.assertTrue(record_time < 0.15 * graph_warmup_time)
self.assertTrue(opt_time < record_time)
def test_forward_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape
vocab_size = config.vocab_size
model = model_class(config).to(device=torch_device).eval()
# num_logits_to_keep=0 is a special case meaning "keep all logits"
all_logits = model(**inputs, num_logits_to_keep=0).logits
last_token_logits = model(**inputs, num_logits_to_keep=1).logits
# Assert all shapes are correct
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size))
# Assert the last tokens are actually the same
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits))
global_rng = random.Random()