# Note that zamba does not have the `apply_rotary_pos_emb` function! from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.models.zamba.modeling_zamba import ZambaAttention # When following ZambaAttention dependencies, the function `apply_rotary_pos_emb` is not present # by default as it is absent from the class definition (and the file altogether). # Note that this syntax should be able to add both `apply_rotary_pos_emb` as imported directly, but # `rotate_half` as well as a dependency from the imported function!! class TestAttention(ZambaAttention): def __init__(self): pass def forward(self): _ = apply_rotary_pos_emb(1, 1, 1, 1)