from math import log
from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel


def _pre_process_input(input_ids):
    # math.log expects a plain Python number, not a tensor, so log the token
    # count rather than the tensor itself.
    print(log(input_ids.numel()))
    return input_ids


# example where we need some deps and some functions
class DummyModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Run the custom pre-processing step, then delegate to LlamaModel.forward,
        # forwarding the pre-processed input_ids. Keyword arguments keep the call
        # robust to upstream signature changes.
        input_ids = _pre_process_input(input_ids)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
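
# --- Minimal usage sketch (illustrative only) ---------------------------------
# A quick smoke test showing how DummyModel might be exercised. The tiny
# LlamaConfig values below are arbitrary assumptions chosen so the example runs
# quickly on CPU with randomly initialized weights; they are not part of the
# example above.
if __name__ == "__main__":
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=64,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = DummyModel(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    outputs = model(input_ids=input_ids)
    print(outputs.last_hidden_state.shape)  # expected: torch.Size([1, 8, 32])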