Merge tensor operations with device transfer operations (#37097)

* Merge tensor operations with `.to()` device transfers

Signed-off-by: cyy <cyyever@outlook.com>

* Use the `dtype` argument instead of chained casts

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
cyyever 2025-04-02 21:15:23 +08:00 committed by GitHub
parent c94c6ed397
commit 764ab0d46a
67 changed files with 209 additions and 113 deletions
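
Every hunk below applies the same pattern: instead of creating a tensor on the default device and moving it afterwards with `.to(device)` (sometimes after a chained `.float()` or `.long()` cast), the `device` argument is passed directly to the tensor factory, or the cast and the transfer are folded into a single `.to()` call, avoiding an intermediate allocation and an extra device copy. A minimal sketch of the before/after pattern (shapes and variable names here are illustrative, not taken from the diff):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Before: the tensor is materialized on the default device (CPU),
# then cast and/or copied to the target device in extra steps.
mask = torch.zeros((4, 4), dtype=torch.bool).to(device)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2, dtype=torch.int64).float().to(device) / 8))

# After: the tensor is created directly on the target device, and any
# chained cast is merged into the factory call or a single .to().
mask = torch.zeros((4, 4), dtype=torch.bool, device=device)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / 8))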

View File

@@ -724,7 +724,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
advance_state.reset(pre_seq.tolist())
if not advance_state.completed:
-advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
+advance_tokens = torch.tensor(advance_state.advance(), dtype=torch.long, device=device)
for advance_token in advance_tokens:
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
new_state = advance_state.copy(stateful=True)
@@ -775,14 +775,14 @@ class ConstrainedBeamSearchScorer(BeamScorer):
track_new["new_states"].append(advance_state)
if len(track_new["new_indices"]) > 0:
-new_indices = torch.tensor(track_new["new_indices"]).to(device)
+new_indices = torch.tensor(track_new["new_indices"], device=device)
new_tokens = torch.stack(track_new["new_tokens"]).to(device)
new_scores = torch.stack(track_new["new_scores"]).to(device)
all_states = topk_contraint_states + track_new["new_states"]
all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
all_scores = torch.cat((sent_beam_scores, new_scores), -1)
-all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
+all_banks = torch.tensor([one.get_bank() for one in all_states], device=device)
zipped = all_banks * 100 + all_scores
indices = zipped.sort(descending=True).indices

View File

@@ -719,7 +719,9 @@ class AssistantToTargetTranslator:
"""
target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
-target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
+target_logits: torch.FloatTensor = torch.full(
+target_shape, self.FILTER_VALUE, device=self._assistant_model_device
+)
# Mask for valid indices
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
# Exclude invalid indices

View File

@@ -1157,7 +1157,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
# Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
# with simpler logic.
-self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
+self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float, device=scores.device)
for sequence_ids, bias in self.sequence_bias.items():
if len(sequence_ids) == 1:
self.length_1_bias[sequence_ids[-1]] = bias

View File

@@ -2599,7 +2599,7 @@ class GenerationMixin:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
-this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
+this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then

View File

@@ -64,7 +64,7 @@ def _compute_default_rope_parameters(
attention_factor = 1.0 # Unused in this type of RoPE
# Compute the inverse frequencies
-inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
+inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
return inv_freq, attention_factor
@@ -156,7 +156,7 @@ def _compute_dynamic_ntk_parameters(
# Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
-inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
+inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
return inv_freq, attention_factor
@@ -241,14 +241,14 @@ def _compute_yarn_parameters(
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# to expand the possible context length. In other words, interpolation = apply scaling factor.
-pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
# Get n-dimensional rotational scaling corrected for extrapolation
-inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor

View File

@@ -783,7 +783,9 @@ class AriaTextRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -173,7 +173,9 @@ class BambaRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -897,7 +897,7 @@ class BarkSemanticModel(BarkCausalModel):
# pass input_ids in order to stay consistent with the transformers generate method even though it is not used
# (except to get the input seq_len - that's why we keep the first 257 tokens)
semantic_output = super().generate(
-torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
+torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device),
input_embeds=input_embeds,
logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
generation_config=semantic_generation_config,
@@ -989,8 +989,8 @@ class BarkCoarseModel(BarkCausalModel):
else:
# shape: (batch_size, 0)
-x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device)
-x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device)
+x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
+x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
return x_semantic_history, x_coarse_history
@@ -1097,7 +1097,7 @@ class BarkCoarseModel(BarkCausalModel):
input_coarse = torch.hstack(
[
input_coarse,
-torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size).to(self.device),
+torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device),
x_coarse[:, -max_coarse_history:],
]
)

View File

@@ -1198,7 +1198,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
image_embeds = vision_outputs[0]
-image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
+image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
if isinstance(input_ids, list):
input_ids = torch.LongTensor(input_ids)
@@ -1424,7 +1424,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
image_embeds = vision_outputs[0]
-image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
+image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
if isinstance(input_ids, list):
input_ids = torch.LongTensor(input_ids)
@@ -1439,7 +1439,9 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
question_embeds = question_outputs[0]
-question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
+question_attention_mask = torch.ones(
+question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device
+)
bos_ids = torch.full(
(question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device

View File

@@ -2498,7 +2498,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
if use_image_text_matching_head:
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(query_tokens.device)
+query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device)
attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1)
query_embeds = self.embeddings(

View File

@@ -1158,7 +1158,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
subsequent_token_logits = subsequent_token_logits.masked_fill(
invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min
)
-self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
+self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
subsequent_token_logits = subsequent_token_logits.masked_fill(
self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min
)
@@ -1287,13 +1287,13 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
batch_size, max_seq_length = attention_mask.shape
device = attention_mask.device
-self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
+self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
mask = bbox_first_token_mask.view(-1)
bbox_first_token_mask = torch.cat(
[
~bbox_first_token_mask,
-torch.zeros([batch_size, 1], dtype=torch.bool).to(device),
+torch.zeros([batch_size, 1], dtype=torch.bool, device=device),
],
axis=1,
)

View File

@@ -95,7 +95,10 @@ class ChameleonRotaryEmbedding(nn.Module):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
-inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+inv_freq = 1.0 / (
+self.base
+** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
+)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
@@ -138,7 +141,8 @@ class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
-base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+base
+** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation

View File

@@ -300,7 +300,7 @@ class DisentangledSelfAttention(nn.Module):
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
att_span = self.pos_ebd_size
-relative_pos = relative_pos.long().to(query_layer.device)
+relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
if self.share_att_key:

View File

@@ -233,7 +233,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:

View File

@@ -113,7 +113,9 @@ class DeepseekV3RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -237,7 +237,7 @@ def prepare_coco_panoptic_annotation(
new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
if "segments_info" in target:
-masks = read_image(annotation_path).permute(1, 2, 0).to(torch.int32).to(image.device)
+masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
masks = rgb_to_id(masks)
ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)

View File

@@ -73,7 +73,10 @@ class OpenLlamaRotaryEmbedding(nn.Module):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
-inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+inv_freq = 1.0 / (
+self.base
+** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
+)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
@@ -135,7 +138,10 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
-inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+inv_freq = 1.0 / (
+base
+** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
+)
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

View File

@@ -254,7 +254,7 @@ def prepare_coco_panoptic_annotation(
new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
if "segments_info" in target:
-masks = read_image(annotation_path).permute(1, 2, 0).to(torch.int32).to(image.device)
+masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
masks = rgb_to_id(masks)
ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)

View File

@@ -675,7 +675,9 @@ class DiffLlamaRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -1256,7 +1256,9 @@ class Emu3RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -158,7 +158,9 @@ class FalconRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -149,7 +149,9 @@ class GemmaRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -396,7 +396,9 @@ class Gemma2RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -191,7 +191,9 @@ class Gemma3RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -313,7 +313,9 @@ class GlmRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -243,7 +243,7 @@ class GPT2Attention(nn.Module):
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:

View File

@@ -219,7 +219,7 @@ class GPTNeoSelfAttention(nn.Module):
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None: # no matter the length, we just slice it

View File

@@ -337,7 +337,9 @@ class GPTNeoXRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -281,7 +281,9 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -367,7 +367,9 @@ class GraniteRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -211,7 +211,9 @@ class GraniteMoeRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -801,7 +801,9 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -132,7 +132,9 @@ class HeliumRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -651,7 +651,7 @@ class SymmetricQuantFunction(Function):
Returns:
`torch.Tensor`: Symmetric-quantized value of *input*.
"""
-zero_point = torch.tensor(0.0).to(scale.device)
+zero_point = torch.tensor(0.0, device=scale.device)
n = 2 ** (k - 1) - 1
new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)

View File

@@ -415,7 +415,10 @@ class IdeficsEmbedding(torch.nn.Module):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
-inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+inv_freq = 1.0 / (
+self.base
+** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
+)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.

View File

@@ -247,7 +247,7 @@ class ImageGPTAttention(nn.Module):
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
@@ -297,7 +297,7 @@ class ImageGPTAttention(nn.Module):
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:

View File

@@ -443,7 +443,9 @@ class JetMoeRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -138,7 +138,9 @@ class LlamaRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -412,7 +412,9 @@ class MimiRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -321,7 +321,9 @@ class MistralRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -443,7 +443,9 @@ class MixtralRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -295,7 +295,9 @@ class ModernBertRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -366,7 +366,9 @@ class MoonshineRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -356,7 +356,9 @@ class MoshiRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -143,7 +143,9 @@ class NemotronRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -332,7 +332,9 @@ class OlmoRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -333,7 +333,9 @@ class Olmo2RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -206,7 +206,9 @@ class OlmoeRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -1315,7 +1315,7 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel):
# [batch_size, height*width, channels]
new_vision_features = torch.cat(new_vision_features, 1)
-new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64).to(vision_features[0].device)
+new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64, device=vision_features[0].device)
level_start_index = torch.cat((new_vision_shapes.new_zeros((1,)), new_vision_shapes.prod(1).cumsum(0)[:-1]))
return new_vision_features, new_vision_shapes, new_vision_shapes_list, level_start_index
@@ -1330,7 +1330,9 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel):
)
predicted_class_features = self.encoder_vision_features(
torch.where(
-valid_mask, vision_features, torch.tensor(0.0, dtype=vision_features.dtype).to(vision_features.device)
+valid_mask,
+vision_features,
+torch.tensor(0.0, dtype=vision_features.dtype, device=vision_features.device),
)
)

View File

@@ -113,7 +113,9 @@ class PersimmonRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -328,7 +328,9 @@ class PhiRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -334,7 +334,9 @@ class Qwen2RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -216,7 +216,9 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -361,7 +361,9 @@ class Qwen3RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -456,7 +456,9 @@ class Qwen3MoeRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -2873,7 +2873,7 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
)
# tgt_lang gets priority over decoder input ids
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
else:
raise ValueError(
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
@@ -3144,7 +3144,7 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
)
# tgt_lang gets priority over decoder input ids
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
else:
raise ValueError(
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
@@ -3420,7 +3420,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
@@ -3441,7 +3441,8 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1)
idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1)
idx_most_probable_sequences_per_batch = (
-idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences
+idx_most_probable_sequences_per_batch
++ torch.arange(batch_size, device=self.device) * num_return_sequences
)
sequences = sequences[idx_most_probable_sequences_per_batch]
@@ -3462,8 +3463,8 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
# Compute t2u decoder_input_ids
t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids")
t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang)
-t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to(
-self.device
+t2u_decoder_input_ids = torch.tensor(
+[[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device
)
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
# second generation
@@ -3480,9 +3481,9 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
)
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
-vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
+vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
-spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device)
+spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device)
waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id)
@@ -3748,7 +3749,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
@@ -3779,7 +3780,8 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1)
idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1)
idx_most_probable_sequences_per_batch = (
-idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences
+idx_most_probable_sequences_per_batch
++ torch.arange(batch_size, device=self.device) * num_return_sequences
)
sequences = sequences[idx_most_probable_sequences_per_batch]
@@ -3800,8 +3802,8 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
# Compute t2u decoder_input_ids
t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids")
t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang)
-t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to(
-self.device
+t2u_decoder_input_ids = torch.tensor(
+[[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device
)
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
@@ -3819,9 +3821,9 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
)
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
-vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
+vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
-spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device)
+spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device)
waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id)
@@ -4171,7 +4173,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
if tgt_lang is not None:
# tgt_lang gets priority over decoder input ids
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
@@ -4221,7 +4223,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1)
idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1)
idx_most_probable_sequences_per_batch = (
-idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences
+idx_most_probable_sequences_per_batch
++ torch.arange(batch_size, device=self.device) * num_return_sequences
)
sequences = sequences[idx_most_probable_sequences_per_batch]
@@ -4242,8 +4245,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
# Compute t2u decoder_input_ids
t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids")
t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang)
-t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to(
-self.device
+t2u_decoder_input_ids = torch.tensor(
+[[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device
)
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
@@ -4261,9 +4264,9 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
)
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
-vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
+vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
-spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device)
+spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device)
waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id)

View File

@@ -3153,7 +3153,7 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
)
# tgt_lang gets priority over decoder input ids
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
else:
raise ValueError(
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
@@ -3434,7 +3434,7 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin
)
# tgt_lang gets priority over decoder input ids
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
else:
raise ValueError(
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
@@ -3720,7 +3720,7 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
@@ -3810,9 +3810,9 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin
)
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
-vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
+vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
-speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device)
+speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
waveform, waveform_lengths = self.vocoder(
input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
@@ -4090,7 +4090,7 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix
text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
@@ -4190,9 +4190,9 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix
)
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
-vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
+vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
-speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device)
+speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
waveform, waveform_lengths = self.vocoder(
input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
@@ -4559,7 +4559,7 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
if tgt_lang is not None:
# tgt_lang gets priority over decoder input ids
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
-text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
+text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
@@ -4679,9 +4679,9 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
)
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
-vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
+vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
-speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device)
+speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
waveform, waveform_lengths = self.vocoder(
input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id

View File

@@ -586,7 +586,7 @@ class SegGptImageProcessor(BaseImageProcessor):
palette_tensor = None
palette = self.get_palette(num_labels) if num_labels is not None else None
if palette is not None:
-palette_tensor = torch.tensor(palette).float().to(masks.device)
+palette_tensor = torch.tensor(palette).to(device=masks.device, dtype=torch.float)
_, num_channels, _, _ = masks.shape
palette_tensor = palette_tensor.view(1, 1, num_labels + 1, num_channels)

View File

@@ -820,7 +820,7 @@ class DisentangledSelfAttention(nn.Module):
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
att_span = self.pos_ebd_size
-relative_pos = relative_pos.long().to(query_layer.device)
+relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
if self.share_att_key:

View File

@@ -431,7 +431,7 @@ class SpeechT5RelativePositionalEncoding(torch.nn.Module):
def forward(self, hidden_states):
seq_len = hidden_states.shape[1]
-pos_seq = torch.arange(0, seq_len).long().to(hidden_states.device)
+pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_seq[pos_seq < -self.max_length] = -self.max_length

View File

@@ -118,7 +118,9 @@ class StableLmRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -325,7 +325,9 @@ class Starcoder2RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -1243,8 +1243,8 @@ class TapasForQuestionAnswering(TapasPreTrainedModel):
if table_mask is None:
table_mask = torch.where(row_ids > 0, torch.ones_like(row_ids), torch.zeros_like(row_ids))
# torch.FloatTensor[batch_size, seq_length]
-input_mask_float = attention_mask.float().to(device)
-table_mask_float = table_mask.float().to(device)
+input_mask_float = attention_mask.to(device=device, dtype=torch.float)
+table_mask_float = table_mask.to(device=device, dtype=torch.float)
# Mask for cells that exist in the table (i.e. that are not padding).
cell_mask, _ = reduce_mean(input_mask_float, cell_index)

View File

@@ -270,7 +270,9 @@ class Zamba2RotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
+freqs = (
+inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
+).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

View File

@@ -2454,7 +2454,7 @@ class Trainer:
self.state.init_training_references(self, max_steps, num_train_epochs, trial)
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
-tr_loss = torch.tensor(0.0).to(args.device)
+tr_loss = torch.tensor(0.0, device=args.device)
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step

View File

@@ -225,7 +225,7 @@ def distributed_broadcast_scalars(
device: Optional[torch.device] = torch.device("cuda"),
) -> torch.Tensor:
try:
-tensorized_scalar = torch.tensor(scalars).to(device)
+tensorized_scalar = torch.tensor(scalars, device=device)
output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)

View File

@@ -596,7 +596,7 @@ def is_torch_bf16_available_on_device(device):
return True
try:
-x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
+x = torch.zeros(2, 2, dtype=torch.bfloat16, device=device)
_ = x @ x
except: # noqa: E722
# TODO: more precise exception matching, if possible.