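"""Example "diff" modeling file.

``DummyModel`` is written as a thin subclass of ``LlamaModel``: it overrides
``forward`` and relies on an extra import (``log``) and a module-level helper
(``_pre_process_input``). Files like this are meant to be expanded by the
repository's conversion tooling into standalone modeling code."""
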
from math import log
from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel


def _pre_process_input(input_ids):
    # Trivial placeholder pre-processing step: it exists so that the
    # `from math import log` dependency above is actually used. Here we just
    # print the natural log of the sequence length and return the ids as-is.
    print(log(input_ids.shape[-1]))
    return input_ids


# Example where the child model needs some extra dependencies and some helper
# functions in addition to the parent class it inherits from.
class DummyModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # Run the custom pre-processing step before delegating to the parent.
        input_ids = _pre_process_input(input_ids)

        # Forward the pre-processed ids and the remaining arguments unchanged,
        # positionally, in the order LlamaModel.forward expects them.
        return super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )
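

# ---------------------------------------------------------------------------
# Not part of the original example: a minimal smoke-test sketch. It assumes a
# toy LlamaConfig small enough to run on CPU and simply checks that the
# override above wires its arguments through to LlamaModel.forward. (The
# conversion tooling referenced in the repository is expected to expand diff
# files like this one into standalone modeling code; that step is not shown
# here.)
if __name__ == "__main__":
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = DummyModel(config)
    dummy_input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        output = model(input_ids=dummy_input_ids)
    # LlamaModel returns hidden states of shape (batch, seq_len, hidden_size),
    # so this should print torch.Size([1, 8, 32]).
    print(output.last_hidden_state.shape)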