mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Create the arange tensor on device for enabling CUDA-Graph for Clip Encoder (#19503)
* create the arange tensor on device for enabling CUDA-Graph at higher performance for SD * sync Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
parent
6cd8676cf3
commit
f6fa0f0bf0
@ -662,7 +662,7 @@ class CLIPTextTransformer(nn.Module):
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
|
||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
|
@ -1134,7 +1134,7 @@ class GroupViTTextTransformer(nn.Module):
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
|
||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
|
Loading…
Reference in New Issue
Block a user