from typing import List, Optional, Tuple, Union

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel

from ...cache_utils import Cache


# example where we need some deps and some functions
class SuperModel(LlamaModel):
    """`LlamaModel` subclass whose forward pass rescales the parent's `logits`."""

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the parent forward pass and multiply its `logits` by ``2**4``.

        Every parameter is handed to ``LlamaModel.forward`` unchanged; refer to
        that method for the meaning of each one.

        Returns:
            The object returned by the parent forward pass, after its
            ``logits`` attribute has been scaled in place by ``2**4`` (16).
        """
        # Forward by keyword rather than by position: with ten pass-through
        # arguments, a parent signature whose parameter order differs would
        # otherwise silently bind values to the wrong parameters.
        out = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # NOTE(review): this needs `out` to expose a `logits` attribute. A plain
        # tuple (which the return annotation allows) would not, and it is not
        # visible here that the parent's output type defines `logits` at all --
        # confirm against `LlamaModel.forward`.
        out.logits *= 2**4
        return out