Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Remove extra tensor clone in PyTorch code (#36748)
* Use detach().clone()
* Eliminate contiguous()
* Merge clone and other calls with to
This commit is contained in:
parent
121830ab47
commit
d68a91aebf
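The pattern the diff below applies, as a minimal standalone sketch (toy tensors and variable names, not code from the repository): chains of `clone()` / `float()` / `to(device)` are collapsed into a single `Tensor.to(copy=True, dtype=..., device=...)` call, and `clone().detach()` is reordered to `detach().clone()` so the copy is never tracked by autograd.

```python
import torch

logits = torch.randn(2, 8, 32000, requires_grad=True)  # stand-in for outputs.logits
device = torch.device("cpu")                            # stand-in for input_ids.device

# Before: separate clone, dtype cast, and device move, each of which may allocate a tensor
old = logits[:, -1, :].clone().float().to(device)

# After: one fused call; copy=True guarantees a fresh tensor even if dtype and device already match
new = logits[:, -1, :].to(copy=True, dtype=torch.float32, device=device)
assert torch.equal(old, new)

# detach().clone(): the copy is made outside the autograd graph,
# whereas clone().detach() first records the clone as a graph operation
w = torch.randn(4, 4, requires_grad=True)
w_copy = w.detach().clone()
assert not w_copy.requires_grad and w_copy.data_ptr() != w.data_ptr()
```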
@@ -2697,7 +2697,7 @@ class GenerationMixin:
 )
 
 # .float() is needed to retain precision for later logits manipulations
-final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float()
+final_layer_next_token_logits = outputs.logits[:, -1, :].detach().to(copy=True, dtype=torch.float32)
 final_logits = outputs.logits[:, -1, :].float()
 candidate_premature_logits = {}
 for candidate_premature_layer in candidate_premature_layers:
@@ -2885,11 +2885,12 @@ class GenerationMixin:
 last_hidden_states = outputs.hidden_states[-1]
 
 # next logit for contrastive search to select top-k candidate tokens
-# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
+# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
 # (the clone itself is always small)
-# .float() is needed to retain precision for later logits manipulations
-logit_for_next_step = outputs.logits[:, -1, :].clone().float()
-logit_for_next_step = logit_for_next_step.to(input_ids.device)
+# torch.float32 is needed to retain precision for later logits manipulations
+logit_for_next_step = outputs.logits[:, -1, :].to(
+    copy=True, dtype=torch.float32, device=input_ids.device
+)
 
 model_kwargs = self._update_model_kwargs_for_generation(
     outputs,
@@ -3297,10 +3298,9 @@ class GenerationMixin:
 if synced_gpus and this_peer_finished:
     continue
 
-# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
 # (the clone itself is always small)
-next_token_logits = outputs.logits[:, -1, :].clone().float()
-next_token_logits = next_token_logits.to(input_ids.device)
+next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
 
 # pre-process distribution
 next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -3768,8 +3768,8 @@ class GenerationMixin:
 if synced_gpus and this_peer_finished:
     continue
 
-logits = model_outputs.logits[:, -1, :].clone().float()  # Clone is needed to avoid keeping a hanging ref
-logits = logits.to(input_ids.device)
+# Copy is needed to avoid keeping a hanging ref
+logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
 
 # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
 # `temperature`, ...), and add new logprobs to existing running logprobs scores.
@@ -4045,10 +4045,9 @@ class GenerationMixin:
 if output_scores:
     processed_score = torch.zeros_like(outputs.logits[:, -1, :])
 if output_logits:
-    # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+    # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
     # (the clone itself is always small)
-    raw_logit_score = outputs.logits[:, -1, :].clone()
-    raw_logit_score = raw_logit_score.to(input_ids.device)
+    raw_logit_score = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
 
 for beam_group_idx in range(num_beam_groups):
     group_start_idx = beam_group_idx * num_sub_beams
@@ -4067,8 +4066,9 @@ class GenerationMixin:
 # select outputs of beams of current group only
 # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
 # .float() is needed to retain precision for later logits manipulations
-next_token_logits = outputs.logits[batch_group_indices, -1, :].float()
-next_token_logits = next_token_logits.to(input_ids.device)
+next_token_logits = outputs.logits[batch_group_indices, -1, :].to(
+    dtype=torch.float32, device=input_ids.device
+)
 
 next_token_scores = nn.functional.log_softmax(
     next_token_logits, dim=-1
@@ -4322,11 +4322,10 @@ class GenerationMixin:
 cur_len = cur_len + 1
 continue
 
-# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
 # (the clone itself is always small)
 # .float() is needed to retain precision for later logits manipulations
-next_token_logits = outputs.logits[:, -1, :].clone().float()
-next_token_logits = next_token_logits.to(input_ids.device)
+next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
 next_token_scores = nn.functional.log_softmax(
     next_token_logits, dim=-1
 )  # (batch_size * num_beams, vocab_size)
@@ -4574,8 +4573,9 @@ class GenerationMixin:
 
 # 2.3. Process the new logits
 # .float() is needed to retain precision for later logits manipulations
-new_logits = outputs.logits[:, -candidate_length - 1 :].float()  # excludes the input prompt if present
-new_logits = new_logits.to(input_ids.device)
+new_logits = outputs.logits[:, -candidate_length - 1 :].to(
+    dtype=torch.float32, device=input_ids.device
+)  # excludes the input prompt if present
 next_token_logits = new_logits.clone()
 if len(logits_processor) > 0:
     for i in range(candidate_length + 1):
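Why the hunks above copy rather than keep a view (an illustrative sketch with made-up shapes, not repository code): `outputs.logits[:, -1, :]` is a view, and a view keeps the whole logits buffer alive; the explicit copy materializes only the last position so the large tensor can be freed after the first iteration.

```python
import torch

full_logits = torch.randn(2, 64, 32000)   # stand-in for outputs.logits (potentially very large)

view = full_logits[:, -1, :]              # a view over the full buffer
assert view._base is full_logits          # ._base (internal attr) shows the view still pins full_logits

copy = full_logits[:, -1, :].to(copy=True, dtype=torch.float32)
assert copy._base is None                 # independent (2, 32000) tensor; full_logits can be garbage collected
```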
@@ -446,7 +446,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
 
 device = weight.device
 dtype = weight.dtype
-weight = weight.clone().float()
+weight = weight.to(copy=True, dtype=torch.float32)
 # Pad to Hadamard transform size
 weight = pad_to_block(weight, [1], hadamard_size)
 
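One detail worth noting for the `quantize_with_higgs` change (general PyTorch behaviour, stated here as context): without `copy=True`, `Tensor.to` returns the tensor itself when no dtype or device conversion is needed, so the flag is what preserves the unconditional copy that `clone().float()` used to make.

```python
import torch

w = torch.randn(16, 16, dtype=torch.float32)

assert w.to(torch.float32) is w                 # already float32: .to() is a no-op returning self
assert w.to(torch.float32, copy=True) is not w  # copy=True: always a new tensor
assert w.to(copy=True, dtype=torch.float32).data_ptr() != w.data_ptr()
```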
@@ -2205,12 +2205,12 @@ class JukeboxPrior(PreTrainedModel):
 loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims
 
 metrics = {
-    "bpd": next_token_prediction_loss.clone().detach(),
-    "encoder_loss": encoder_loss.clone().detach(),
-    "next_token_prediction_loss": next_token_prediction_loss.clone().detach(),
+    "bpd": next_token_prediction_loss.detach().clone(),
+    "encoder_loss": encoder_loss.detach().clone(),
+    "next_token_prediction_loss": next_token_prediction_loss.detach().clone(),
 }
 if get_preds:
-    metrics["preds"] = preds.clone().detach()
+    metrics["preds"] = preds.detach().clone()
 if get_attn_weights:
     saved_attn_weights = self.prior.transformer.saved_attn_weights
     self.prior.transformer.set_record_attn(False)
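The Jukebox metrics (and the pruning utilities further down) only flip the order from `clone().detach()` to `detach().clone()`. A small sketch of why the order matters (toy tensors, not repository code): detaching first is essentially free, and the clone then happens outside the autograd graph instead of being recorded and detached afterwards.

```python
import torch

x = torch.randn(3, requires_grad=True)

tracked = x.clone()              # grad_fn=<CloneBackward0>: the clone is part of the graph
untracked = x.detach().clone()   # detach is a no-copy view; the clone is never tracked

assert tracked.grad_fn is not None
assert untracked.grad_fn is None and not untracked.requires_grad
assert torch.equal(tracked.detach(), untracked)   # the resulting values are identical either way
```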
@@ -148,7 +148,7 @@ def write_model(model_path, input_base_path, model_size, safe_serialization=True
 w3 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w3"]
 
 experts_w1 = [
-    w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+    w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
     for expert_idx in range(num_local_experts)
 ]
 
@@ -157,16 +157,16 @@ def write_model(model_path, input_base_path, model_size, safe_serialization=True
 state_dict[expert_key + ".weight"] = expert_block.clone()
 
 experts_w2 = [
-    w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+    w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
     for expert_idx in range(num_local_experts)
 ]
 
 for idx, expert_block in enumerate(experts_w2):
     expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w2"
-    state_dict[expert_key + ".weight"] = expert_block.T.clone().contiguous()
+    state_dict[expert_key + ".weight"] = expert_block.T.clone(memory_format=torch.contiguous_format)
 
 experts_w3 = [
-    w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+    w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
     for expert_idx in range(num_local_experts)
 ]
 
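For the expert-weight conversion script (`write_model`), the change folds `.contiguous().clone()` and `.clone().contiguous()` into a single `clone(memory_format=torch.contiguous_format)`. A sketch with a small matrix (assumed shapes, not the script's real weights): row slices of a contiguous matrix were already contiguous, so one call was redundant, while for the transposed expert block the old chain could copy the data twice.

```python
import torch

w = torch.randn(8, 4)

# Row slices of a contiguous matrix are already contiguous, so .contiguous() was a no-op there
assert w[0:4, :].is_contiguous()

# A transpose is not contiguous; clone() tends to keep the transposed strides, forcing a second copy
two_copies = w.T.clone().contiguous()
one_copy = w.T.clone(memory_format=torch.contiguous_format)   # single contiguous copy

assert one_copy.is_contiguous()
assert torch.equal(two_copies, one_copy)
```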
@@ -131,7 +131,7 @@ class VideoMAEEmbeddings(nn.Module):
 embeddings = self.patch_embeddings(pixel_values)
 
 # add position embeddings
-embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
+embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).detach().clone()
 # only keep visible patches
 # ~bool_masked_pos means visible
 if bool_masked_pos is not None:
@@ -856,7 +856,7 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
 if bool_masked_pos is None:
     raise ValueError("One must provided a boolean mask ")
 expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
-expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach()
+expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).detach().clone()
 pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
 pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)
 
@@ -73,12 +73,12 @@ def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0)
     `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
 """
 index = index.to(layer.weight.device)
-W = layer.weight.index_select(dim, index).clone().detach()
+W = layer.weight.index_select(dim, index).detach().clone()
 if layer.bias is not None:
     if dim == 1:
-        b = layer.bias.clone().detach()
+        b = layer.bias.detach().clone()
     else:
-        b = layer.bias[index].clone().detach()
+        b = layer.bias[index].detach().clone()
 new_size = list(layer.weight.size())
 new_size[dim] = len(index)
 new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
@@ -137,11 +137,11 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
     [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
 """
 index = index.to(layer.weight.device)
-W = layer.weight.index_select(dim, index).clone().detach()
+W = layer.weight.index_select(dim, index).detach().clone()
 if dim == 0:
-    b = layer.bias.clone().detach()
+    b = layer.bias.detach().clone()
 else:
-    b = layer.bias[index].clone().detach()
+    b = layer.bias[index].detach().clone()
 new_size = list(layer.weight.size())
 new_size[dim] = len(index)
 new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
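A miniature of the pruning pattern these utilities implement (a hypothetical re-sketch with made-up sizes, not the library functions themselves): the kept rows are copied out with `detach().clone()` and written into a fresh, trainable layer.

```python
import torch
from torch import nn

layer = nn.Linear(4, 6)
index = torch.tensor([0, 2, 5])          # output features to keep

W = layer.weight.index_select(0, index).detach().clone()
b = layer.bias[index].detach().clone()

new_layer = nn.Linear(4, len(index))
with torch.no_grad():                    # fill the new parameters without tracking gradients
    new_layer.weight.copy_(W)
    new_layer.bias.copy_(b)

assert torch.equal(new_layer.weight, layer.weight[index])
assert new_layer.weight.requires_grad    # still a trainable leaf
```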