Fix more inefficient PT operations (#37060)
* Fix inefficient operations

* Remove cpu() call

* Reorder detach()

* Reorder detach()

* tolist without detach

* item without detach

* Update src/transformers/models/rag/modeling_rag.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/encodec/test_modeling_encodec.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Use detach().cpu().numpy

* Revert some numpy operations

* More fixes

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Parent: a1e389e637 · Commit: 786d9c5ed9
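The recurring patterns behind these changes: `Tensor.tolist()` and `Tensor.item()` already detach the tensor and copy it to host memory, so chains like `.clone().detach().cpu().tolist()` do redundant work; calling `.detach()` before a copy keeps the copy out of the autograd graph; a single fused `.to(device=..., dtype=...)` replaces two chained transfers; and passing `device=` at construction avoids an allocate-then-move. A minimal sketch with toy tensors (the variables below are illustrative assumptions, not code from this diff):

import torch

x = torch.randn(4, requires_grad=True)

# tolist() already detaches and copies to CPU, so the long chain is redundant.
assert x.clone().detach().cpu().tolist() == x.tolist()

y = x * 2
fast = y.detach().cpu()  # detach first: autograd never records the host copy
slow = y.cpu().detach()  # here the copy is tracked by autograd, then discarded

# One fused transfer (device + dtype) instead of two chained copies.
fused = y.detach().to(device="cpu", dtype=torch.float32)

# Allocate directly on the target device rather than allocating then moving.
z = torch.zeros(2, 2, dtype=torch.float16, device="cpu")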
@@ -204,7 +204,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
         time_to_first_token = end - start
         logger.info(f"completed first compile generation in: {time_to_first_token}s")
         cache_position += 1
-        all_generated_tokens += next_token.clone().detach().cpu().tolist()
+        all_generated_tokens += next_token.tolist()

         cache_position = torch.tensor([seq_length], device=device)
         ### First compile, decoding

@@ -217,7 +217,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
         time_to_second_token = end - start
         logger.info(f"completed second compile generation in: {time_to_second_token}s")
         cache_position += 1
-        all_generated_tokens += next_token.clone().detach().cpu().tolist()
+        all_generated_tokens += next_token.tolist()

         ### Second compile, decoding
         start = perf_counter()

@@ -229,13 +229,13 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
         time_to_third_token = end - start
         logger.info(f"completed third compile forward in: {time_to_third_token}s")
         cache_position += 1
-        all_generated_tokens += next_token.clone().detach().cpu().tolist()
+        all_generated_tokens += next_token.tolist()

         ### Using cuda graphs decoding

         start = perf_counter()
         for _ in range(1, num_tokens_to_generate):
-            all_generated_tokens += next_token.clone().detach().cpu().tolist()
+            all_generated_tokens += next_token.tolist()
             next_token = decode_one_token(
                 model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
             )

@@ -68,7 +68,7 @@ def set_seed(args):


 def to_list(tensor):
-    return tensor.detach().cpu().tolist()
+    return tensor.tolist()


 def train(args, train_dataset, model, tokenizer):

@@ -129,7 +129,7 @@ class AgentImage(AgentType, ImageType):
             return self._raw

         if self._tensor is not None:
-            array = self._tensor.cpu().detach().numpy()
+            array = self._tensor.detach().cpu().numpy()
             return Image.fromarray((255 - array * 255).astype(np.uint8))

     def to_string(self):

@@ -147,7 +147,7 @@ class AgentImage(AgentType, ImageType):
             return self._path

         if self._tensor is not None:
-            array = self._tensor.cpu().detach().numpy()
+            array = self._tensor.detach().cpu().numpy()

             # There is likely simpler than load into image into save
             img = Image.fromarray((255 - array * 255).astype(np.uint8))

@@ -64,4 +64,4 @@ class TextToSpeechTool(PipelineTool):

     def decode(self, outputs):
         with torch.no_grad():
-            return self.post_processor(outputs).cpu().detach()
+            return self.post_processor(outputs).detach().cpu()

@@ -612,7 +612,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                 if is_beam_token_worse_than_top_num_beams:
                     continue

-                completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
+                completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].tolist())
                 if completes_constraint:
                     if beam_indices is not None:
                         beam_index = beam_indices[batch_beam_idx]

@@ -718,19 +718,19 @@ class ConstrainedBeamSearchScorer(BeamScorer):
             # hypotheses.

             topk_state = topk_contraint_states[seq_idx]
-            topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
+            topk_state.reset(full_hypotheses[seq_idx].tolist())

             advance_state = advance_constraint_states[seq_idx]
-            advance_state.reset(pre_seq.cpu().tolist())
+            advance_state.reset(pre_seq.tolist())

             if not advance_state.completed:
                 advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
                 for advance_token in advance_tokens:
                     # since adding each `advance_token` leads to a different hypothesis, create new state instance.
                     new_state = advance_state.copy(stateful=True)
-                    new_state.add(advance_token.cpu().tolist())
+                    new_state.add(advance_token.tolist())

-                    advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
+                    advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).tolist()
                     if advance_seq not in track_new["new_seqs"]:
                         # prevent duplicates, which are basically bound to happen in this process.
                         track_new["new_seqs"].append(advance_seq)

@@ -763,7 +763,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):

                 advance_state = advance_constraint_states[seq_idx]

-                advance_seq = advance_seq.cpu().tolist()
+                advance_seq = advance_seq.tolist()

                 advance_state.reset(advance_seq)
                 if advance_seq not in track_new["new_seqs"]:

@@ -843,7 +843,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]

-                completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
+                completes_constraint = self.check_completes_constraints(final_tokens.tolist())
                 if completes_constraint:
                     beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                     generated_len = final_tokens.shape[-1] - decoder_prompt_len

@@ -3768,7 +3768,7 @@ class GenerationMixin:
             device=input_ids.device,
         )
         running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
-        sequences = running_sequences.clone().detach()
+        sequences = running_sequences.detach().clone()

         # per batch, beam-item score, logprobs
         # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens

@@ -3789,7 +3789,7 @@ class GenerationMixin:
         running_beam_indices = torch.full(
             (batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
         )
-        beam_indices = running_beam_indices.clone().detach()
+        beam_indices = running_beam_indices.detach().clone()

         # 4. run the generation loop
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):

@@ -5100,7 +5100,7 @@ def _dola_select_contrast(

     # 6. Reduce the batchmean
     js_divs = js_divs.mean(-1)  # shape: (num_premature_layers,)
-    premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]
+    premature_layer = candidate_premature_layers[int(js_divs.argmax().item())]

     base_logits = candidate_premature_logits[premature_layer]
     final_logits, base_logits = _relative_top_filter(final_logits, base_logits)

@@ -175,8 +175,8 @@ class RTDetrLoss(nn.Module):

         src_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([_target["boxes"][i] for _target, (_, i) in zip(targets, indices)], dim=0)
-        ious, _ = box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
-        ious = torch.diag(ious).detach()
+        ious, _ = box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))
+        ious = torch.diag(ious)

         src_logits = outputs["logits"]
         target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])

@@ -190,7 +190,7 @@ class RTDetrLoss(nn.Module):
         target_score_original[idx] = ious.to(target_score_original.dtype)
         target_score = target_score_original.unsqueeze(-1) * target

-        pred_score = F.sigmoid(src_logits).detach()
+        pred_score = F.sigmoid(src_logits.detach())
         weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score

         loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none")

@@ -123,7 +123,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -148,7 +148,7 @@ def get_alignment(music_tokens, labels, prior, config):
             del w_hop
         weights = torch.cat(w_hops, dim=0)
         del w_hops
-        alignment_hop = weights.float().cpu().numpy()
+        alignment_hop = weights.to(device="cpu", dtype=torch.float).numpy()
         del weights

         # alignment_hop has shape (bs, n_ctx, nb_relevant_lyric_tokens)

@@ -1568,7 +1568,7 @@ class DetrImageProcessor(BaseImageProcessor):
         def to_tuple(tup):
             if isinstance(tup, tuple):
                 return tup
-            return tuple(tup.cpu().tolist())
+            return tuple(tup.tolist())

         for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
             # we filter empty queries and detection below threshold

@@ -1677,7 +1677,7 @@ class DetrImageProcessor(BaseImageProcessor):
         def to_tuple(tup):
             if isinstance(tup, tuple):
                 return tup
-            return tuple(tup.cpu().tolist())
+            return tuple(tup.tolist())

         for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
             out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes

@@ -831,7 +831,7 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
         def to_tuple(tup):
             if isinstance(tup, tuple):
                 return tup
-            return tuple(tup.cpu().tolist())
+            return tuple(tup.tolist())

         for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
             # we filter empty queries and detection below threshold

@@ -940,7 +940,7 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
         def to_tuple(tup):
             if isinstance(tup, tuple):
                 return tup
-            return tuple(tup.cpu().tolist())
+            return tuple(tup.tolist())

         for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
             out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes

@@ -120,7 +120,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -951,7 +951,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
         >>> height = image_processor.size["height"]
         >>> width = image_processor.size["width"]

-        >>> samples = output[:, 1:].cpu().detach().numpy()
+        >>> samples = output[:, 1:].detach().cpu().numpy()
         >>> samples_img = [
         ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
         ... ]  # convert color cluster tokens back to pixels

@@ -182,7 +182,7 @@ class MgpstrProcessor(ProcessorMixin):
         for index in range(batch_size):
             pred_eos = preds_str[index].find(eos_str)
             pred = preds_str[index][:pred_eos]
-            pred_index = preds_index[index].cpu().tolist()
+            pred_index = preds_index[index].tolist()
             pred_eos_index = pred_index.index(eos_token) if eos_token in pred_index else -1
             pred_max_prob = preds_max_prob[index][: pred_eos_index + 1]
             confidence_score = pred_max_prob.cumprod(dim=0)[-1] if pred_max_prob.nelement() != 0 else 0.0

@@ -1167,7 +1167,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -583,7 +583,7 @@ class RagModel(RagPreTrainedModel):

                 retriever_outputs = self.retriever(
                     input_ids,
-                    question_encoder_last_hidden_state.cpu().detach().to(torch.float32).numpy(),
+                    question_encoder_last_hidden_state.detach().to(device="cpu", dtype=torch.float32).numpy(),
                     prefix=self.generator.config.prefix,
                     n_docs=n_docs,
                     return_tensors="pt",

@@ -974,7 +974,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
             context_input_ids = self.retriever(
                 input_ids,
-                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
                 prefix=self.generator.config.prefix,
                 n_docs=n_docs,
                 return_tensors="pt",

@@ -1462,7 +1462,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
             out = self.retriever(
                 input_ids,
-                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
                 prefix=self.generator.config.prefix,
                 n_docs=n_docs,
                 return_tensors="pt",

@@ -123,7 +123,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -112,7 +112,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -142,7 +142,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -155,7 +155,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -172,7 +172,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
    )

@@ -276,7 +276,7 @@ class UnivNetFeatureExtractor(SequenceFeatureExtractor):
             `List[np.ndarray]`: A ragged list of 1D waveform arrays with padding removed.
         """
         # Collapse the batched waveform tensor to a list of 1D audio waveforms
-        waveforms = [waveform.detach().clone().cpu().numpy() for waveform in waveforms]
+        waveforms = [waveform.detach().to(device="cpu", copy=True).numpy() for waveform in waveforms]

         if waveform_lengths is not None:
             waveforms = [waveform[: waveform_lengths[i]] for i, waveform in enumerate(waveforms)]

@@ -131,7 +131,9 @@ class VideoMAEEmbeddings(nn.Module):
         embeddings = self.patch_embeddings(pixel_values)

         # add position embeddings
-        embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).detach().clone()
+        embeddings = embeddings + self.position_embeddings.detach().type_as(embeddings).to(
+            device=embeddings.device, copy=True
+        )
         # only keep visible patches
         # ~bool_masked_pos means visible
         if bool_masked_pos is not None:

@@ -856,7 +858,7 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
         if bool_masked_pos is None:
             raise ValueError("One must provided a boolean mask ")
         expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
-        expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).detach().clone()
+        expanded_position_embeddings = expanded_position_embeddings.detach().to(device=pixel_values.device, copy=True)
         pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
         pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)

@@ -190,7 +190,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -147,7 +147,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -165,7 +165,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -128,7 +128,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -150,7 +150,7 @@ def _compute_mask_indices(

     # compute number of masked spans in batch
     input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
+        attention_mask.detach().sum(-1).tolist()
         if attention_mask is not None
         else [sequence_length for _ in range(batch_size)]
     )

@@ -18,6 +18,8 @@ from .base import Pipeline


 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
     from ..models.speecht5.modeling_speecht5 import SpeechT5HifiGan

@@ -213,7 +215,7 @@ class TextToAudioPipeline(Pipeline):
             waveform = waveform["waveform"]
         elif isinstance(waveform, tuple):
             waveform = waveform[0]
-        output_dict["audio"] = waveform.cpu().float().numpy()
+        output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy()
         output_dict["sampling_rate"] = self.sampling_rate

         return output_dict

@@ -2524,9 +2524,7 @@ class Trainer:
                     else:
                         input_tokens = inputs[main_input_name].numel()
                         input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
-                        self.state.num_input_tokens_seen += (
-                            self.accelerator.gather(input_tokens).sum().cpu().item()
-                        )
+                        self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False

@@ -3076,7 +3074,7 @@ class Trainer:

             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
-                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+                logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
             if learning_rate is not None:
                 logs["learning_rate"] = learning_rate
             else:

@@ -4559,7 +4557,7 @@ class Trainer:
         if has_labels or loss_without_labels:
             with self.compute_loss_context_manager():
                 loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-            loss = loss.mean().detach()
+            loss = loss.detach().mean()

             if isinstance(outputs, dict):
                 logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])

@@ -1204,7 +1204,7 @@ if is_sagemaker_mp_enabled():
             return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
         # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
         # which is also the name of the decorator so Python is confused.
-        return tensor.concat().detach().cpu()
+        return tensor.detach().concat().cpu()


 @dataclass

@@ -351,9 +351,9 @@ class Seq2SeqTrainer(Trainer):
             with self.compute_loss_context_manager():
                 outputs = model(**inputs)
             if self.label_smoother is not None:
-                loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                loss = self.label_smoother(outputs, inputs["labels"]).detach().mean()
             else:
-                loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+                loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).detach().mean()
         else:
             loss = None

@@ -276,7 +276,7 @@ def to_py_obj(obj):
         return [to_py_obj(o) for o in obj]

     framework_to_py_obj = {
-        "pt": lambda obj: obj.detach().cpu().tolist(),
+        "pt": lambda obj: obj.tolist(),
         "tf": lambda obj: obj.numpy().tolist(),
         "jax": lambda obj: np.asarray(obj).tolist(),
         "np": lambda obj: obj.tolist(),

@@ -564,7 +564,7 @@ def is_torch_fp16_available_on_device(device):
     import torch

     try:
-        x = torch.zeros(2, 2, dtype=torch.float16).to(device)
+        x = torch.zeros(2, 2, dtype=torch.float16, device=device)
         _ = x @ x

         # At this moment, let's be strict of the check: check if `LayerNorm` is also supported on device, because many

@@ -522,9 +522,9 @@ class ConstrainedBeamSearchTester:
         # set to same device. we don't care what device.

         if not isinstance(tensor_1, list):
-            tensor_1 = tensor_1.cpu().tolist()
+            tensor_1 = tensor_1.tolist()
         if not isinstance(tensor_2, list):
-            tensor_2 = tensor_2.cpu().tolist()
+            tensor_2 = tensor_2.tolist()

         in_order = len(tensor_1) <= len(tensor_2)
         longer = tensor_2 if in_order else tensor_1

@@ -2595,9 +2595,9 @@ class GenerationTesterMixin:
         # set to same device. we don't care what device.

         if not isinstance(tensor_1, list):
-            tensor_1 = tensor_1.cpu().tolist()
+            tensor_1 = tensor_1.tolist()
         if not isinstance(tensor_2, list):
-            tensor_2 = tensor_2.cpu().tolist()
+            tensor_2 = tensor_2.tolist()

         in_order = len(tensor_1) <= len(tensor_2)
         longer = tensor_2 if in_order else tensor_1

@@ -363,7 +363,7 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
         output_non_padded = model.generate(input_ids=inputs_non_padded)

-        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
         inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
         output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

@@ -406,7 +406,7 @@ class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
         output_non_padded = model.generate(input_ids=inputs_non_padded)

-        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
         inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
         output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

@@ -352,7 +352,7 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase):
         torch.testing.assert_close(logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)

         expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]]
-        self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2)
+        self.assertEqual(logits[0].topk(2).indices.tolist(), expected_top2)

     @slow
     def test_inference_interpolate_pos_encoding(self):

@@ -117,7 +117,7 @@ class EncodecModelTester:
             config.normalize = True

             processor = EncodecFeatureExtractor(feature_size=config.audio_channels, sampling_rate=config.sampling_rate)
-            input_values = list(input_values.cpu().numpy())
+            input_values = input_values.tolist()
             inputs_dict = processor(
                 input_values, sampling_rate=config.sampling_rate, padding=True, return_tensors="pt"
             ).to(torch_device)

@@ -495,7 +495,7 @@ class EncodecIntegrationTest(unittest.TestCase):
             # use max bandwidth for best possible reconstruction
             encoder_outputs = model.encode(inputs["input_values"], bandwidth=float(bandwidth))

-            audio_code_sums = [a[0].sum().cpu().item() for a in encoder_outputs[0]]
+            audio_code_sums = [a[0].sum().item() for a in encoder_outputs[0]]

             # make sure audio encoded codes are correct
             self.assertListEqual(audio_code_sums, expected_codesums[bandwidth])

@@ -552,7 +552,7 @@ class EncodecIntegrationTest(unittest.TestCase):
             encoder_outputs = model.encode(
                 inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth), return_dict=False
             )
-            audio_code_sums = [a[0].sum().cpu().item() for a in encoder_outputs[0]]
+            audio_code_sums = [a[0].sum().item() for a in encoder_outputs[0]]

             # make sure audio encoded codes are correct
             self.assertListEqual(audio_code_sums, expected_codesums[bandwidth])

@@ -610,8 +610,8 @@ class EncodecIntegrationTest(unittest.TestCase):
         with torch.no_grad():
             # use max bandwidth for best possible reconstruction
             encoder_outputs = model.encode(input_values, bandwidth=float(bandwidth), return_dict=False)
-            audio_code_sums_0 = [a[0][0].sum().cpu().item() for a in encoder_outputs[0]]
-            audio_code_sums_1 = [a[0][1].sum().cpu().item() for a in encoder_outputs[0]]
+            audio_code_sums_0 = [a[0][0].sum().item() for a in encoder_outputs[0]]
+            audio_code_sums_1 = [a[0][1].sum().item() for a in encoder_outputs[0]]

             # make sure audio encoded codes are correct
             self.assertListEqual(audio_code_sums_0, expected_codesums[bandwidth][0])

@@ -662,7 +662,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
         output_non_padded = model.generate(input_ids=inputs_non_padded, max_length=20)

-        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
         inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
         output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

@@ -724,7 +724,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
         output_non_padded = model.generate(input_ids=inputs_non_padded, max_length=20)

-        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
         inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
         output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

@@ -552,7 +552,7 @@ class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
         inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
         output_non_padded = model.generate(input_ids=inputs_non_padded)

-        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
         inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
         output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

@@ -466,7 +466,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
         output_non_padded = model.generate(input_ids=inputs_non_padded)

-        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
         inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
         output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

@@ -540,7 +540,7 @@ class MimiIntegrationTest(unittest.TestCase):
             # use max bandwidth for best possible reconstruction
             encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks))

-            audio_code_sums = encoder_outputs[0].sum().cpu().item()
+            audio_code_sums = encoder_outputs[0].sum().item()

             # make sure audio encoded codes are correct
             # assert relative difference less than a threshold, because `audio_code_sums` varies a bit

@@ -951,8 +951,8 @@ class MoshiIntegrationTests(unittest.TestCase):
         expected_text_token = 452
         expected_audio_tokens = [916, 1396, 1238, 579, 1105, 914, 1257, 810]  # fmt: skip

-        self.assertTrue(expected_text_token == model_outputs.sequences[0, -2].cpu().item())
-        self.assertTrue(expected_audio_tokens == model_outputs.audio_codes[0, :, -1].cpu().tolist())
+        self.assertTrue(expected_text_token == model_outputs.sequences[0, -2].item())
+        self.assertTrue(expected_audio_tokens == model_outputs.audio_codes[0, :, -1].tolist())

     @slow
     def test_moshiko_greedy_unconditional_fp16_eager(self):

@@ -966,7 +966,7 @@ class MoshiIntegrationTests(unittest.TestCase):
         )

         # eager equivalence is not as strict as sdpa.
-        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist())
+        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())

     @slow
     def test_moshiko_greedy_unconditional_fp32(self):

@@ -986,8 +986,8 @@ class MoshiIntegrationTests(unittest.TestCase):
         audio_code_sums = model_outputs.audio_codes.sum().item()
         self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= (3e-3 * audio_code_sums))

-        self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].cpu().tolist())
-        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist())
+        self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].tolist())
+        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())

     @slow
     @require_torch_fp16

@@ -1008,8 +1008,8 @@ class MoshiIntegrationTests(unittest.TestCase):
         audio_code_sums = model_outputs.audio_codes.sum().item()
         self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= (3e-3 * audio_code_sums))

-        self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].cpu().tolist())
-        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist())
+        self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].tolist())
+        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())

     @slow
     @require_torch_fp16

@@ -1030,5 +1030,5 @@ class MoshiIntegrationTests(unittest.TestCase):
         audio_code_sums = model_outputs.audio_codes.sum().item()
         self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= 2048)

-        self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].cpu().tolist())
-        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist())
+        self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].tolist())
+        self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())

@@ -486,7 +486,7 @@ class OPTGenerationTest(unittest.TestCase):
         inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
         output_non_padded = model.generate(input_ids=inputs_non_padded)

-        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
         inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
         output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

@@ -989,7 +989,7 @@ class Owlv2ModelIntegrationTest(unittest.TestCase):
             outputs, text_labels=text_labels
         )

-        objects_labels = post_processed_output_with_text_labels[0]["labels"].cpu().tolist()
+        objects_labels = post_processed_output_with_text_labels[0]["labels"].tolist()
         self.assertListEqual(objects_labels, [0, 0])

         objects_text_labels = post_processed_output_with_text_labels[0]["text_labels"]

@@ -975,7 +975,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
             outputs, text_labels=text_labels
         )

-        objects_labels = post_processed_output_with_text_labels[0]["labels"].cpu().tolist()
+        objects_labels = post_processed_output_with_text_labels[0]["labels"].tolist()
         self.assertListEqual(objects_labels, [0, 0])

         objects_text_labels = post_processed_output_with_text_labels[0]["text_labels"]

@@ -311,7 +311,7 @@ class RagTestMixin:

         out = retriever(
             input_ids,
-            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+            question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
             prefix=config.generator.prefix,
             return_tensors="pt",
         )

@@ -379,7 +379,7 @@ class RagTestMixin:

         out = retriever(
             input_ids,
-            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+            question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
             prefix=config.generator.prefix,
             return_tensors="pt",
         )

@@ -438,7 +438,7 @@ class RagTestMixin:

         out = retriever(
             input_ids,
-            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+            question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
             prefix=config.generator.prefix,
             return_tensors="pt",
             n_docs=n_docs,

@@ -507,7 +507,7 @@ class RagTestMixin:

         out = retriever(
             input_ids,
-            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+            question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
             prefix=config.generator.prefix,
             return_tensors="pt",
             n_docs=retriever_n_docs,

@@ -964,7 +964,7 @@ class RagModelIntegrationTests(unittest.TestCase):

         question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0]
         docs_dict = retriever(
-            input_ids.cpu().detach().numpy(), question_hidden_states.cpu().detach().numpy(), return_tensors="pt"
+            input_ids.detach().cpu().numpy(), question_hidden_states.detach().cpu().numpy(), return_tensors="pt"
         )
         doc_scores = torch.bmm(
             question_hidden_states.unsqueeze(1),

@@ -1044,7 +1044,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
         ).loss

         # loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
-        self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())
+        self.assertTrue(loss.item() <= loss_more_masked.item())

     def test_mask_feature_prob_ctc(self):
         model = Wav2Vec2ForCTC.from_pretrained(

@@ -670,7 +670,7 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         #
         # input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
         # logits = model(input_values).logits
-        # pred_ids = torch.argmax(logits, axis=-1).cpu().tolist()
+        # pred_ids = torch.argmax(logits, axis=-1).tolist()
         # ```
         # fmt: off
         pred_ids = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 11, 0, 0, 0, 22, 0, 0, 4, 4, 4, 14, 0, 0, 0, 0, 0, 8, 8, 0, 5, 5, 0, 12, 0, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 10, 0, 0, 0, 15, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 7, 0, 9, 0, 0, 14, 0, 0, 0, 13, 0, 7, 0, 0, 4, 4, 0, 15, 8, 8, 0, 0, 8, 0, 26, 0, 0, 4, 4, 0, 0, 15, 0, 0, 0, 0, 0, 0, 10, 0, 26, 5, 5, 0, 4, 4, 0, 0, 12, 11, 0, 0, 5, 4, 4, 4, 0, 18, 0, 0, 0, 7, 9, 9, 0, 6, 0, 12, 12, 4, 4, 0, 6, 0, 0, 8, 0, 4, 4, 4, 0, 19, 0, 0, 8, 9, 9, 0, 0, 0, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 16, 16, 0, 0, 17, 5, 5, 5, 0, 4, 4, 4, 0, 0, 29, 29, 0, 0, 0, 0, 8, 11, 0, 9, 9, 0, 0, 0, 4, 4, 0, 12, 12, 0, 0, 0, 9, 0, 0, 0, 0, 0, 8, 18, 0, 0, 0, 4, 4, 0, 0, 8, 9, 0, 4, 4, 0, 6, 11, 5, 0, 4, 4, 0, 13, 13, 0, 0, 0, 10, 0, 0, 25, 0, 0, 6, 0, 4, 4, 0, 0, 0, 0, 7, 0, 0, 23, 0, 0, 4, 4, 0, 0, 0, 6, 11, 0, 5, 4, 4, 18, 0, 0, 0, 0, 0, 0, 7, 15, 0, 0, 0, 15, 15, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

@@ -527,4 +527,4 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
         ]  # the president the president the president the president the president the president the president the president the president the president
         # TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
         output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].cpu().numpy().tolist(), expected_output_ids)
+        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)

@@ -4231,7 +4231,7 @@ class ModelTesterMixin:
             loss = model(**inputs).loss
             loss.backward()

-            params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()}
+            params = {name: param.grad.detach().to(device="cpu", copy=True) for name, param in model.named_parameters()}
             model.zero_grad()
             del loss

@@ -2378,14 +2378,14 @@ class AttentionMaskTester(unittest.TestCase):
             num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

             if 0 not in mask_2d:
-                assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
+                assert (mask_4d != 0).sum().item() == num_tokens_masked
             if 0 in mask_2d:
                 # at least causal mask + maybe more
-                assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
+                assert (mask_4d != 0).sum().item() >= num_tokens_masked
                 self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
         elif not mask_converter.is_causal and context is None:
             if 0 not in mask_2d:
-                assert (mask_4d != 0).sum().cpu().item() == 0
+                assert (mask_4d != 0).sum().item() == 0
             if 0 in mask_2d:
                 self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
         elif mask_converter.is_causal and context is not None:

@@ -2394,10 +2394,10 @@ class AttentionMaskTester(unittest.TestCase):
             num_tokens_masked = bsz * num_tokens_masked

             if 0 not in mask_2d:
-                assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
+                assert (mask_4d != 0).sum().item() == num_tokens_masked
             if 0 in mask_2d:
                 # at least causal mask + maybe more
-                assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
+                assert (mask_4d != 0).sum().item() >= num_tokens_masked
                 self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)

     def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):

@@ -2415,15 +2415,15 @@ class AttentionMaskTester(unittest.TestCase):
             # k * (k+1) / 2 tokens are masked in triangualar masks
             num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

-            assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
+            assert (mask_4d != 0).sum().item() == num_tokens_masked
         elif not mask_converter.is_causal and context is None:
-            assert (mask_4d != 0).sum().cpu().item() == 0
+            assert (mask_4d != 0).sum().item() == 0
         elif mask_converter.is_causal and context is not None:
             # k * (k+1) / 2 tokens are masked in triangualar masks
             num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
             num_tokens_masked = bsz * num_tokens_masked

-            assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
+            assert (mask_4d != 0).sum().item() == num_tokens_masked

     def compute_num_context_mask(self, kv_len, context, q_len):
         # This function computes the # of attention tokens that are added for