Create the arange tensor on device for enabling CUDA-Graph for Clip Encoder (#19503)

* create the arange tensor on device for enabling CUDA-Graph at higher-performace for SD

* sync

Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Reza Yazdani 2022-10-12 14:32:50 -07:00 committed by GitHub
parent 6cd8676cf3
commit f6fa0f0bf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -662,7 +662,7 @@ class CLIPTextTransformer(nn.Module):
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
]
if not return_dict:

View File

@ -1134,7 +1134,7 @@ class GroupViTTextTransformer(nn.Module):
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
]
if not return_dict: