update base_model_prefix

Shuming Hu 2025-07-02 16:39:54 +00:00
parent a6820acfa7
commit d50b8c9684
2 changed files with 78 additions and 2 deletions

@@ -90,7 +90,7 @@ class PerceptionLMMultiModalProjector(nn.Module):
 @auto_docstring
 class PerceptionLMPreTrainedModel(PreTrainedModel):
     config_class = PerceptionLMConfig
-    base_model_prefix = ""
+    base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
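
For context on the one-line change above: `base_model_prefix` names the attribute that holds the base model, and `from_pretrained` uses it to add or strip that prefix when checkpoint keys and model keys disagree (for example, loading a bare `PerceptionLMModel` checkpoint into `PerceptionLMForConditionalGeneration`, which stores the base model as `self.model`). A minimal sketch of the key-remapping idea, as an illustration only and not the actual `transformers` loading code:

# Sketch of the key remapping enabled by a non-empty base_model_prefix.
# Illustration only; not the transformers implementation.
def remap_keys(state_dict: dict, prefix: str, target_expects_prefix: bool) -> dict:
    if not prefix:
        # With base_model_prefix = "" there is no prefix to add or strip,
        # so base-model and head-model checkpoints cannot be reconciled.
        return dict(state_dict)
    dotted = prefix + "."
    out = {}
    for key, value in state_dict.items():
        if target_expects_prefix and not key.startswith(dotted):
            out[dotted + key] = value        # base checkpoint -> head model
        elif not target_expects_prefix and key.startswith(dotted):
            out[key[len(dotted):]] = value   # head checkpoint -> base model
        else:
            out[key] = value
    return out

# e.g. "language_model.embed_tokens.weight" saved from PerceptionLMModel becomes
# "model.language_model.embed_tokens.weight" for the generation wrapper.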
@@ -234,6 +234,43 @@ class PerceptionLMModel(PerceptionLMPreTrainedModel):
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[tuple, PerceptionLMModelOutputWithPast]:
"""
Forward pass of the PerceptionLM model.
Args:
input_ids (`torch.LongTensor`, *optional*):
Indices of input sequence tokens in the vocabulary.
pixel_values (`torch.FloatTensor`, *optional*):
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
pixel_values_videos (`torch.FloatTensor`, *optional*):
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
attention_mask (`torch.Tensor`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor`, *optional*):
Indices of positions of each input sequence token in the position embeddings.
past_key_values (`list[torch.FloatTensor]`, *optional*):
Precomputed key and value hidden states for fast autoregressive generation.
inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
use_cache (`bool`, *optional*):
Whether or not to use past key values to speed up decoding.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor`, *optional*):
Indices of the input tokens' positions in the full sequence. Unlike `position_ids`, these are not affected by padding; they are used to update the cache at the correct positions.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
If an `int`, compute logits only for the last `logits_to_keep` tokens (`0` keeps logits for all tokens); if a `torch.Tensor`, the indices to keep along the sequence dimension.
**lm_kwargs:
Additional keyword arguments for the language model.
Returns:
[`PerceptionLMModelOutputWithPast`] or `tuple`:
Model outputs as a `PerceptionLMModelOutputWithPast` if `return_dict=True`, otherwise a tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -394,6 +431,45 @@ class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]:
"""
Forward pass for the PerceptionLMForConditionalGeneration model.
Args:
input_ids (`torch.LongTensor`, *optional*):
Indices of input sequence tokens in the vocabulary.
pixel_values (`torch.FloatTensor`, *optional*):
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
pixel_values_videos (`torch.FloatTensor`, *optional*):
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
attention_mask (`torch.Tensor`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor`, *optional*):
Indices of positions of each input sequence token in the position embeddings.
past_key_values (`list[torch.FloatTensor]`, *optional*):
Precomputed key and value hidden states for fast autoregressive generation.
inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
labels (`torch.LongTensor`, *optional*):
Labels for computing the language modeling loss. Indices should be in `[0, ..., config.vocab_size - 1]`; positions set to `-100` are ignored (masked) when computing the loss.
use_cache (`bool`, *optional*):
Whether or not to use past key values to speed up decoding.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor`, *optional*):
Indices of the input tokens' positions in the full sequence. Unlike `position_ids`, these are not affected by padding; they are used to update the cache at the correct positions.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
If an `int`, compute logits only for the last `logits_to_keep` tokens (`0` keeps logits for all tokens); if a `torch.Tensor`, the indices to keep along the sequence dimension.
**lm_kwargs:
Additional keyword arguments for the language model.
Returns:
[`PerceptionLMCausalLMOutputWithPast`] or `tuple`:
Model outputs as a `PerceptionLMCausalLMOutputWithPast` if `return_dict=True`, otherwise a tuple.
"""
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
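
A matching hedged sketch for the head model: passing `labels` makes `forward` return a loss, and `generate` exercises the cached decoding path (same assumed checkpoint id and prompt format as above):

# Loss computation and generation sketch (assumed checkpoint id and prompt).
from PIL import Image
from transformers import AutoProcessor, PerceptionLMForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/Perception-LM-1B")  # assumed id
model = PerceptionLMForConditionalGeneration.from_pretrained("facebook/Perception-LM-1B")

image = Image.open("scene.jpg")
inputs = processor(images=image, text="<image> What is shown?", return_tensors="pt")

# Passing labels triggers the causal LM loss inside forward (-100 is ignored).
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)

# Autoregressive decoding reuses past_key_values when use_cache is enabled.
generated = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])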

@@ -99,7 +99,7 @@ class PerceptionLMMultiModalProjector(nn.Module):
 class PerceptionLMPreTrainedModel(LlavaPreTrainedModel):
-    base_model_prefix = ""
+    base_model_prefix = "model"
 class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast):