PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
This commit is contained in:
Abhishek 2024-10-10 20:15:34 -04:00
parent d2796f6f12
commit 39d2868e5c
No known key found for this signature in database
GPG Key ID: 36C864B70E8D4349

View File

@ -27,7 +27,6 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import typing_extensions
from .dynamic_module_utils import custom_object_save
@ -93,6 +92,7 @@ class FlashAttentionKwargs(TypedDict, total=False):
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""
import torch
cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]