transformers/examples/modular-transformers/modular_switch_function.py


# Note that llama and cohere have different definitions for rotate_half
from transformers.models.cohere.modeling_cohere import rotate_half  # noqa
from transformers.models.llama.modeling_llama import LlamaAttention


# When following LlamaAttention's dependencies, we would normally grab the `rotate_half`
# defined in `modeling_llama.py`. But here we imported it explicitly from Cohere, so the
# generated modeling file should use Cohere's definition instead.
class SwitchFunctionAttention(LlamaAttention):
    pass
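

# For reference, a sketch of how the two upstream definitions differ (paraphrased from the
# respective modeling files; the exact upstream code may change between transformers versions).
#
# Llama rotates by splitting the hidden dimension into two contiguous halves:
#
#     def rotate_half(x):
#         x1 = x[..., : x.shape[-1] // 2]
#         x2 = x[..., x.shape[-1] // 2 :]
#         return torch.cat((-x2, x1), dim=-1)
#
# Cohere rotates interleaved even/odd channels instead:
#
#     def rotate_half(x):
#         x1 = x[..., ::2]
#         x2 = x[..., 1::2]
#         return torch.stack([-x2, x1], dim=-1).flatten(-2)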