Remove extra tensor clone in PyTorch code (#36748)

* Use detach().clone()

* Eliminate contiguous()

* Merge clone and other calls with to

* Merge clone and other calls with to
cyyever 2025-03-26 01:42:15 +08:00 committed by GitHub
parent 121830ab47
commit d68a91aebf
6 changed files with 37 additions and 37 deletions
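The changes below all apply one of three small rewrite patterns. A minimal sketch of the equivalences in plain PyTorch (tensor names are illustrative, not taken from the diff):

    import torch

    x = torch.randn(4, 8, requires_grad=True)

    # 1. Detach before cloning, so the clone itself is never recorded by autograd.
    a_old = x.clone().detach()
    a_new = x.detach().clone()

    # 2. Fold clone()/float() into a single to() call; copy=True still forces a
    #    fresh tensor when dtype and device already match.
    b_old = x.clone().float()
    b_new = x.to(copy=True, dtype=torch.float32)

    # 3. Fold contiguous().clone() into one clone with an explicit memory format,
    #    avoiding a possible intermediate copy.
    t = x.t()  # transposed view, not contiguous
    c_old = t.contiguous().clone()
    c_new = t.clone(memory_format=torch.contiguous_format)

    assert torch.equal(a_old, a_new) and torch.equal(b_old, b_new) and torch.equal(c_old, c_new)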

View File

@@ -2697,7 +2697,7 @@ class GenerationMixin:
)
# .float() is needed to retain precision for later logits manipulations
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float()
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().to(copy=True, dtype=torch.float32)
final_logits = outputs.logits[:, -1, :].float()
candidate_premature_logits = {}
for candidate_premature_layer in candidate_premature_layers:
@@ -2885,11 +2885,12 @@ class GenerationMixin:
last_hidden_states = outputs.hidden_states[-1]
# next logit for contrastive search to select top-k candidate tokens
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
# (the clone itself is always small)
# .float() is needed to retain precision for later logits manipulations
logit_for_next_step = outputs.logits[:, -1, :].clone().float()
logit_for_next_step = logit_for_next_step.to(input_ids.device)
# torch.float32 is needed to retain precision for later logits manipulations
logit_for_next_step = outputs.logits[:, -1, :].to(
copy=True, dtype=torch.float32, device=input_ids.device
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
@@ -3297,10 +3298,9 @@ class GenerationMixin:
if synced_gpus and this_peer_finished:
continue
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -3768,8 +3768,8 @@ class GenerationMixin:
if synced_gpus and this_peer_finished:
continue
logits = model_outputs.logits[:, -1, :].clone().float() # Clone is needed to avoid keeping a hanging ref
logits = logits.to(input_ids.device)
# Copy is needed to avoid keeping a hanging ref
logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
# b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
# `temperature`, ...), and add new logprobs to existing running logprobs scores.
@@ -4045,10 +4045,9 @@ class GenerationMixin:
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
if output_logits:
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
raw_logit_score = outputs.logits[:, -1, :].clone()
raw_logit_score = raw_logit_score.to(input_ids.device)
raw_logit_score = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
@@ -4067,8 +4066,9 @@ class GenerationMixin:
# select outputs of beams of current group only
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[batch_group_indices, -1, :].float()
next_token_logits = next_token_logits.to(input_ids.device)
next_token_logits = outputs.logits[batch_group_indices, -1, :].to(
dtype=torch.float32, device=input_ids.device
)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
@@ -4322,11 +4322,10 @@ class GenerationMixin:
cur_len = cur_len + 1
continue
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
@@ -4574,8 +4573,9 @@ class GenerationMixin:
# 2.3. Process the new logits
# .float() is needed to retain precision for later logits manipulations
new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present
new_logits = new_logits.to(input_ids.device)
new_logits = outputs.logits[:, -candidate_length - 1 :].to(
dtype=torch.float32, device=input_ids.device
) # excludes the input prompt if present
next_token_logits = new_logits.clone()
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
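The recurring "hanging ref" comment refers to the fact that basic slicing returns a view: holding on to outputs.logits[:, -1, :] keeps the full logits buffer alive, while the explicit copy only owns one vocabulary row per batch item. A rough illustration with made-up shapes (PyTorch >= 2.0 for untyped_storage; real sequence and vocabulary sizes are far larger):

    import torch

    logits = torch.randn(2, 64, 1000)  # (batch, seq_len, vocab)

    view = logits[:, -1, :]                                      # shares storage with the full tensor
    copy = logits[:, -1, :].to(copy=True, dtype=torch.float32)   # small standalone tensor

    print(view.untyped_storage().data_ptr() == logits.untyped_storage().data_ptr())   # True: pins the big buffer
    print(copy.untyped_storage().data_ptr() == logits.untyped_storage().data_ptr())   # False: owns its own buffer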

View File

@@ -446,7 +446,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
device = weight.device
dtype = weight.dtype
weight = weight.clone().float()
weight = weight.to(copy=True, dtype=torch.float32)
# Pad to Hadamard transform size
weight = pad_to_block(weight, [1], hadamard_size)
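In quantize_with_higgs the point is that clone().float() allocates a same-dtype copy first and then, if the weight is not already float32, a second copy for the cast, whereas to(copy=True, dtype=torch.float32) casts and copies in one step and still guarantees a fresh tensor when the dtype already matches. A small sketch:

    import torch

    w = torch.randn(4, 4, dtype=torch.float16)

    old = w.clone().float()                      # fp16 copy, then a second fp32 copy
    new = w.to(copy=True, dtype=torch.float32)   # single fp32 copy

    assert torch.equal(old, new) and new.dtype == torch.float32

    w32 = torch.randn(4, 4)
    assert w32.to(copy=True, dtype=torch.float32).data_ptr() != w32.data_ptr()  # copy=True: still a real copy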

View File

@@ -2205,12 +2205,12 @@ class JukeboxPrior(PreTrainedModel):
loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims
metrics = {
"bpd": next_token_prediction_loss.clone().detach(),
"encoder_loss": encoder_loss.clone().detach(),
"next_token_prediction_loss": next_token_prediction_loss.clone().detach(),
"bpd": next_token_prediction_loss.detach().clone(),
"encoder_loss": encoder_loss.detach().clone(),
"next_token_prediction_loss": next_token_prediction_loss.detach().clone(),
}
if get_preds:
metrics["preds"] = preds.clone().detach()
metrics["preds"] = preds.detach().clone()
if get_attn_weights:
saved_attn_weights = self.prior.transformer.saved_attn_weights
self.prior.transformer.set_record_attn(False)
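Swapping clone().detach() to detach().clone() yields the same detached copy, but detaching first means the clone happens entirely outside autograd instead of being recorded and then cut off. A small sketch with a made-up loss:

    import torch

    loss = (torch.randn(3, requires_grad=True) ** 2).sum()

    a = loss.clone().detach()   # clone is tracked by autograd, then the result is detached
    b = loss.detach().clone()   # detach first: no autograd bookkeeping for the clone

    assert torch.equal(a, b)
    assert a.requires_grad is False and b.requires_grad is False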

View File

@@ -148,7 +148,7 @@ def write_model(model_path, input_base_path, model_size, safe_serialization=True
w3 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w3"]
experts_w1 = [
w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
for expert_idx in range(num_local_experts)
]
@@ -157,16 +157,16 @@ def write_model(model_path, input_base_path, model_size, safe_serialization=True
state_dict[expert_key + ".weight"] = expert_block.clone()
experts_w2 = [
w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
for expert_idx in range(num_local_experts)
]
for idx, expert_block in enumerate(experts_w2):
expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w2"
state_dict[expert_key + ".weight"] = expert_block.T.clone().contiguous()
state_dict[expert_key + ".weight"] = expert_block.T.clone(memory_format=torch.contiguous_format)
experts_w3 = [
w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
for expert_idx in range(num_local_experts)
]
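The expert blocks here are slices, and for w2 the block is also transposed, so the source can be non-contiguous; cloning with an explicit memory format produces the contiguous copy directly rather than possibly materializing an intermediate via contiguous() first. A sketch with made-up shapes:

    import torch

    w = torch.randn(6, 4)
    block = w[2:4, :].T   # transposed slice: a non-contiguous view

    old = block.contiguous().clone()                          # may allocate twice
    new = block.clone(memory_format=torch.contiguous_format)  # one contiguous copy

    assert torch.equal(old, new) and new.is_contiguous()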

View File

@@ -131,7 +131,7 @@ class VideoMAEEmbeddings(nn.Module):
embeddings = self.patch_embeddings(pixel_values)
# add position embeddings
embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).detach().clone()
# only keep visible patches
# ~bool_masked_pos means visible
if bool_masked_pos is not None:
@@ -856,7 +856,7 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
if bool_masked_pos is None:
raise ValueError("One must provided a boolean mask ")
expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach()
expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).detach().clone()
pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)

View File

@@ -73,12 +73,12 @@ def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0)
`torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
W = layer.weight.index_select(dim, index).detach().clone()
if layer.bias is not None:
if dim == 1:
b = layer.bias.clone().detach()
b = layer.bias.detach().clone()
else:
b = layer.bias[index].clone().detach()
b = layer.bias[index].detach().clone()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
@@ -137,11 +137,11 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
[`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
W = layer.weight.index_select(dim, index).detach().clone()
if dim == 0:
b = layer.bias.clone().detach()
b = layer.bias.detach().clone()
else:
b = layer.bias[index].clone().detach()
b = layer.bias[index].detach().clone()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
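For reference, the pruning helpers above are used like this (a usage sketch, assuming prune_linear_layer is imported from transformers.pytorch_utils; the index tensor lists the positions to keep):

    import torch
    from torch import nn
    from transformers.pytorch_utils import prune_linear_layer

    layer = nn.Linear(8, 4)
    keep = torch.tensor([0, 2, 5, 7])               # input features to keep

    pruned = prune_linear_layer(layer, keep, dim=1) # new layer with in_features == 4
    print(pruned.weight.shape)                      # torch.Size([4, 4])
    assert pruned.weight.requires_grad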