mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge tensor operations with device transfer operations (#37097)
* Merge operations with to Signed-off-by: cyy <cyyever@outlook.com> * Use dtype Signed-off-by: cyy <cyyever@outlook.com> --------- Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
parent
c94c6ed397
commit
764ab0d46a
@ -724,7 +724,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
advance_state.reset(pre_seq.tolist())
|
||||
|
||||
if not advance_state.completed:
|
||||
advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
|
||||
advance_tokens = torch.tensor(advance_state.advance(), dtype=torch.long, device=device)
|
||||
for advance_token in advance_tokens:
|
||||
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
|
||||
new_state = advance_state.copy(stateful=True)
|
||||
@ -775,14 +775,14 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
track_new["new_states"].append(advance_state)
|
||||
|
||||
if len(track_new["new_indices"]) > 0:
|
||||
new_indices = torch.tensor(track_new["new_indices"]).to(device)
|
||||
new_indices = torch.tensor(track_new["new_indices"], device=device)
|
||||
new_tokens = torch.stack(track_new["new_tokens"]).to(device)
|
||||
new_scores = torch.stack(track_new["new_scores"]).to(device)
|
||||
|
||||
all_states = topk_contraint_states + track_new["new_states"]
|
||||
all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
|
||||
all_scores = torch.cat((sent_beam_scores, new_scores), -1)
|
||||
all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
|
||||
all_banks = torch.tensor([one.get_bank() for one in all_states], device=device)
|
||||
|
||||
zipped = all_banks * 100 + all_scores
|
||||
indices = zipped.sort(descending=True).indices
|
||||
|
@ -719,7 +719,9 @@ class AssistantToTargetTranslator:
|
||||
"""
|
||||
|
||||
target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
|
||||
target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
|
||||
target_logits: torch.FloatTensor = torch.full(
|
||||
target_shape, self.FILTER_VALUE, device=self._assistant_model_device
|
||||
)
|
||||
# Mask for valid indices
|
||||
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
|
||||
# Exclude invalid indices
|
||||
|
@ -1157,7 +1157,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
|
||||
|
||||
# Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
|
||||
# with simpler logic.
|
||||
self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
|
||||
self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float, device=scores.device)
|
||||
for sequence_ids, bias in self.sequence_bias.items():
|
||||
if len(sequence_ids) == 1:
|
||||
self.length_1_bias[sequence_ids[-1]] = bias
|
||||
|
@ -2599,7 +2599,7 @@ class GenerationMixin:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
|
@ -64,7 +64,7 @@ def _compute_default_rope_parameters(
|
||||
attention_factor = 1.0 # Unused in this type of RoPE
|
||||
|
||||
# Compute the inverse frequencies
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
|
||||
return inv_freq, attention_factor
|
||||
|
||||
|
||||
@ -156,7 +156,7 @@ def _compute_dynamic_ntk_parameters(
|
||||
|
||||
# Compute the inverse frequencies
|
||||
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
|
||||
return inv_freq, attention_factor
|
||||
|
||||
|
||||
@ -241,14 +241,14 @@ def _compute_yarn_parameters(
|
||||
|
||||
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
||||
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
|
||||
|
||||
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
||||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
||||
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
||||
|
@ -783,7 +783,9 @@ class AriaTextRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -173,7 +173,9 @@ class BambaRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -897,7 +897,7 @@ class BarkSemanticModel(BarkCausalModel):
|
||||
# pass input_ids in order to stay consistent with the transformers generate method even though it is not used
|
||||
# (except to get the input seq_len - that's why we keep the first 257 tokens)
|
||||
semantic_output = super().generate(
|
||||
torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
|
||||
torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device),
|
||||
input_embeds=input_embeds,
|
||||
logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
|
||||
generation_config=semantic_generation_config,
|
||||
@ -989,8 +989,8 @@ class BarkCoarseModel(BarkCausalModel):
|
||||
|
||||
else:
|
||||
# shape: (batch_size, 0)
|
||||
x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device)
|
||||
x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device)
|
||||
x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
|
||||
x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
|
||||
|
||||
return x_semantic_history, x_coarse_history
|
||||
|
||||
@ -1097,7 +1097,7 @@ class BarkCoarseModel(BarkCausalModel):
|
||||
input_coarse = torch.hstack(
|
||||
[
|
||||
input_coarse,
|
||||
torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size).to(self.device),
|
||||
torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device),
|
||||
x_coarse[:, -max_coarse_history:],
|
||||
]
|
||||
)
|
||||
|
@ -1198,7 +1198,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
|
||||
|
||||
image_embeds = vision_outputs[0]
|
||||
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
|
||||
if isinstance(input_ids, list):
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
@ -1424,7 +1424,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
|
||||
|
||||
image_embeds = vision_outputs[0]
|
||||
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
|
||||
if isinstance(input_ids, list):
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
@ -1439,7 +1439,9 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
|
||||
|
||||
question_embeds = question_outputs[0]
|
||||
|
||||
question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
|
||||
question_attention_mask = torch.ones(
|
||||
question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device
|
||||
)
|
||||
|
||||
bos_ids = torch.full(
|
||||
(question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
|
||||
|
@ -2498,7 +2498,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
||||
|
||||
if use_image_text_matching_head:
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(query_tokens.device)
|
||||
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device)
|
||||
attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1)
|
||||
|
||||
query_embeds = self.embeddings(
|
||||
|
@ -1158,7 +1158,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
subsequent_token_logits = subsequent_token_logits.masked_fill(
|
||||
invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min
|
||||
)
|
||||
self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
|
||||
self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
|
||||
subsequent_token_logits = subsequent_token_logits.masked_fill(
|
||||
self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min
|
||||
)
|
||||
@ -1287,13 +1287,13 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
batch_size, max_seq_length = attention_mask.shape
|
||||
device = attention_mask.device
|
||||
|
||||
self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
|
||||
self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
|
||||
|
||||
mask = bbox_first_token_mask.view(-1)
|
||||
bbox_first_token_mask = torch.cat(
|
||||
[
|
||||
~bbox_first_token_mask,
|
||||
torch.zeros([batch_size, 1], dtype=torch.bool).to(device),
|
||||
torch.zeros([batch_size, 1], dtype=torch.bool, device=device),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
@ -95,7 +95,10 @@ class ChameleonRotaryEmbedding(nn.Module):
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
inv_freq = 1.0 / (
|
||||
self.base
|
||||
** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
# For BC we register cos and sin cached
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
@ -138,7 +141,8 @@ class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
|
||||
base
|
||||
** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
|
||||
|
||||
|
@ -300,7 +300,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
|
||||
|
||||
att_span = self.pos_ebd_size
|
||||
relative_pos = relative_pos.long().to(query_layer.device)
|
||||
relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
|
||||
|
||||
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
|
||||
if self.share_att_key:
|
||||
|
@ -233,7 +233,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
|
@ -113,7 +113,9 @@ class DeepseekV3RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -237,7 +237,7 @@ def prepare_coco_panoptic_annotation(
|
||||
new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
|
||||
|
||||
if "segments_info" in target:
|
||||
masks = read_image(annotation_path).permute(1, 2, 0).to(torch.int32).to(image.device)
|
||||
masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
|
||||
masks = rgb_to_id(masks)
|
||||
|
||||
ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)
|
||||
|
@ -73,7 +73,10 @@ class OpenLlamaRotaryEmbedding(nn.Module):
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
inv_freq = 1.0 / (
|
||||
self.base
|
||||
** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
@ -135,7 +138,10 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
@ -254,7 +254,7 @@ def prepare_coco_panoptic_annotation(
|
||||
new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
|
||||
|
||||
if "segments_info" in target:
|
||||
masks = read_image(annotation_path).permute(1, 2, 0).to(torch.int32).to(image.device)
|
||||
masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
|
||||
masks = rgb_to_id(masks)
|
||||
|
||||
ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)
|
||||
|
@ -675,7 +675,9 @@ class DiffLlamaRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -1256,7 +1256,9 @@ class Emu3RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -158,7 +158,9 @@ class FalconRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -149,7 +149,9 @@ class GemmaRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -396,7 +396,9 @@ class Gemma2RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -191,7 +191,9 @@ class Gemma3RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -313,7 +313,9 @@ class GlmRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -243,7 +243,7 @@ class GPT2Attention(nn.Module):
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
|
@ -219,7 +219,7 @@ class GPTNeoSelfAttention(nn.Module):
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
|
@ -337,7 +337,9 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -281,7 +281,9 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -367,7 +367,9 @@ class GraniteRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -211,7 +211,9 @@ class GraniteMoeRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -801,7 +801,9 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -132,7 +132,9 @@ class HeliumRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -651,7 +651,7 @@ class SymmetricQuantFunction(Function):
|
||||
Returns:
|
||||
`torch.Tensor`: Symmetric-quantized value of *input*.
|
||||
"""
|
||||
zero_point = torch.tensor(0.0).to(scale.device)
|
||||
zero_point = torch.tensor(0.0, device=scale.device)
|
||||
|
||||
n = 2 ** (k - 1) - 1
|
||||
new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
|
||||
|
@ -415,7 +415,10 @@ class IdeficsEmbedding(torch.nn.Module):
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
inv_freq = 1.0 / (
|
||||
self.base
|
||||
** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
|
@ -247,7 +247,7 @@ class ImageGPTAttention(nn.Module):
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
@ -297,7 +297,7 @@ class ImageGPTAttention(nn.Module):
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
|
@ -443,7 +443,9 @@ class JetMoeRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -138,7 +138,9 @@ class LlamaRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -412,7 +412,9 @@ class MimiRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -321,7 +321,9 @@ class MistralRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -443,7 +443,9 @@ class MixtralRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -295,7 +295,9 @@ class ModernBertRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -366,7 +366,9 @@ class MoonshineRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -356,7 +356,9 @@ class MoshiRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -143,7 +143,9 @@ class NemotronRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -332,7 +332,9 @@ class OlmoRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -333,7 +333,9 @@ class Olmo2RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -206,7 +206,9 @@ class OlmoeRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -1315,7 +1315,7 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel):
|
||||
|
||||
# [batch_size, height*width, channels]
|
||||
new_vision_features = torch.cat(new_vision_features, 1)
|
||||
new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64).to(vision_features[0].device)
|
||||
new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64, device=vision_features[0].device)
|
||||
level_start_index = torch.cat((new_vision_shapes.new_zeros((1,)), new_vision_shapes.prod(1).cumsum(0)[:-1]))
|
||||
|
||||
return new_vision_features, new_vision_shapes, new_vision_shapes_list, level_start_index
|
||||
@ -1330,7 +1330,9 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel):
|
||||
)
|
||||
predicted_class_features = self.encoder_vision_features(
|
||||
torch.where(
|
||||
valid_mask, vision_features, torch.tensor(0.0, dtype=vision_features.dtype).to(vision_features.device)
|
||||
valid_mask,
|
||||
vision_features,
|
||||
torch.tensor(0.0, dtype=vision_features.dtype, device=vision_features.device),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -113,7 +113,9 @@ class PersimmonRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -328,7 +328,9 @@ class PhiRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -334,7 +334,9 @@ class Qwen2RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -216,7 +216,9 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -361,7 +361,9 @@ class Qwen3RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -456,7 +456,9 @@ class Qwen3MoeRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -2873,7 +2873,7 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
# tgt_lang gets priority over decoder input ids
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
|
||||
@ -3144,7 +3144,7 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
# tgt_lang gets priority over decoder input ids
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
|
||||
@ -3420,7 +3420,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
|
||||
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
|
||||
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
|
||||
|
||||
@ -3441,7 +3441,8 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1)
|
||||
idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1)
|
||||
idx_most_probable_sequences_per_batch = (
|
||||
idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences
|
||||
idx_most_probable_sequences_per_batch
|
||||
+ torch.arange(batch_size, device=self.device) * num_return_sequences
|
||||
)
|
||||
sequences = sequences[idx_most_probable_sequences_per_batch]
|
||||
|
||||
@ -3462,8 +3463,8 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
# Compute t2u decoder_input_ids
|
||||
t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids")
|
||||
t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang)
|
||||
t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to(
|
||||
self.device
|
||||
t2u_decoder_input_ids = torch.tensor(
|
||||
[[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device
|
||||
)
|
||||
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
|
||||
# second generation
|
||||
@ -3480,9 +3481,9 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device)
|
||||
spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id)
|
||||
|
||||
@ -3748,7 +3749,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
|
||||
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
|
||||
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
|
||||
|
||||
@ -3779,7 +3780,8 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1)
|
||||
idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1)
|
||||
idx_most_probable_sequences_per_batch = (
|
||||
idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences
|
||||
idx_most_probable_sequences_per_batch
|
||||
+ torch.arange(batch_size, device=self.device) * num_return_sequences
|
||||
)
|
||||
sequences = sequences[idx_most_probable_sequences_per_batch]
|
||||
|
||||
@ -3800,8 +3802,8 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
# Compute t2u decoder_input_ids
|
||||
t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids")
|
||||
t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang)
|
||||
t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to(
|
||||
self.device
|
||||
t2u_decoder_input_ids = torch.tensor(
|
||||
[[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device
|
||||
)
|
||||
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
|
||||
|
||||
@ -3819,9 +3821,9 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device)
|
||||
spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id)
|
||||
|
||||
@ -4171,7 +4173,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
if tgt_lang is not None:
|
||||
# tgt_lang gets priority over decoder input ids
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
|
||||
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
|
||||
|
||||
@ -4221,7 +4223,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1)
|
||||
idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1)
|
||||
idx_most_probable_sequences_per_batch = (
|
||||
idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences
|
||||
idx_most_probable_sequences_per_batch
|
||||
+ torch.arange(batch_size, device=self.device) * num_return_sequences
|
||||
)
|
||||
sequences = sequences[idx_most_probable_sequences_per_batch]
|
||||
|
||||
@ -4242,8 +4245,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
# Compute t2u decoder_input_ids
|
||||
t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids")
|
||||
t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang)
|
||||
t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to(
|
||||
self.device
|
||||
t2u_decoder_input_ids = torch.tensor(
|
||||
[[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device
|
||||
)
|
||||
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
|
||||
|
||||
@ -4261,9 +4264,9 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device)
|
||||
spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id)
|
||||
|
||||
|
@ -3153,7 +3153,7 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
# tgt_lang gets priority over decoder input ids
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
|
||||
@ -3434,7 +3434,7 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin
|
||||
)
|
||||
# tgt_lang gets priority over decoder input ids
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
|
||||
@ -3720,7 +3720,7 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin
|
||||
|
||||
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
|
||||
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
|
||||
|
||||
@ -3810,9 +3810,9 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin
|
||||
)
|
||||
|
||||
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device)
|
||||
speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
waveform, waveform_lengths = self.vocoder(
|
||||
input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
|
||||
@ -4090,7 +4090,7 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix
|
||||
text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
|
||||
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
|
||||
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
|
||||
|
||||
@ -4190,9 +4190,9 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix
|
||||
)
|
||||
|
||||
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device)
|
||||
speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
waveform, waveform_lengths = self.vocoder(
|
||||
input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
|
||||
@ -4559,7 +4559,7 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
|
||||
if tgt_lang is not None:
|
||||
# tgt_lang gets priority over decoder input ids
|
||||
text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device)
|
||||
text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
|
||||
|
||||
kwargs_text["decoder_input_ids"] = text_decoder_input_ids
|
||||
|
||||
@ -4679,9 +4679,9 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device)
|
||||
vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device)
|
||||
speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
|
||||
|
||||
waveform, waveform_lengths = self.vocoder(
|
||||
input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
|
||||
|
@ -586,7 +586,7 @@ class SegGptImageProcessor(BaseImageProcessor):
|
||||
palette_tensor = None
|
||||
palette = self.get_palette(num_labels) if num_labels is not None else None
|
||||
if palette is not None:
|
||||
palette_tensor = torch.tensor(palette).float().to(masks.device)
|
||||
palette_tensor = torch.tensor(palette).to(device=masks.device, dtype=torch.float)
|
||||
_, num_channels, _, _ = masks.shape
|
||||
palette_tensor = palette_tensor.view(1, 1, num_labels + 1, num_channels)
|
||||
|
||||
|
@ -820,7 +820,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
|
||||
|
||||
att_span = self.pos_ebd_size
|
||||
relative_pos = relative_pos.long().to(query_layer.device)
|
||||
relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
|
||||
|
||||
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
|
||||
if self.share_att_key:
|
||||
|
@ -431,7 +431,7 @@ class SpeechT5RelativePositionalEncoding(torch.nn.Module):
|
||||
|
||||
def forward(self, hidden_states):
|
||||
seq_len = hidden_states.shape[1]
|
||||
pos_seq = torch.arange(0, seq_len).long().to(hidden_states.device)
|
||||
pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
|
||||
pos_seq = pos_seq[:, None] - pos_seq[None, :]
|
||||
|
||||
pos_seq[pos_seq < -self.max_length] = -self.max_length
|
||||
|
@ -118,7 +118,9 @@ class StableLmRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -325,7 +325,9 @@ class Starcoder2RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -1243,8 +1243,8 @@ class TapasForQuestionAnswering(TapasPreTrainedModel):
|
||||
if table_mask is None:
|
||||
table_mask = torch.where(row_ids > 0, torch.ones_like(row_ids), torch.zeros_like(row_ids))
|
||||
# torch.FloatTensor[batch_size, seq_length]
|
||||
input_mask_float = attention_mask.float().to(device)
|
||||
table_mask_float = table_mask.float().to(device)
|
||||
input_mask_float = attention_mask.to(device=device, dtype=torch.float)
|
||||
table_mask_float = table_mask.to(device=device, dtype=torch.float)
|
||||
# Mask for cells that exist in the table (i.e. that are not padding).
|
||||
cell_mask, _ = reduce_mean(input_mask_float, cell_index)
|
||||
|
||||
|
@ -270,7 +270,9 @@ class Zamba2RotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
@ -2454,7 +2454,7 @@ class Trainer:
|
||||
self.state.init_training_references(self, max_steps, num_train_epochs, trial)
|
||||
|
||||
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
|
||||
tr_loss = torch.tensor(0.0).to(args.device)
|
||||
tr_loss = torch.tensor(0.0, device=args.device)
|
||||
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
|
||||
self._total_loss_scalar = 0.0
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
|
@ -225,7 +225,7 @@ def distributed_broadcast_scalars(
|
||||
device: Optional[torch.device] = torch.device("cuda"),
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
tensorized_scalar = torch.tensor(scalars).to(device)
|
||||
tensorized_scalar = torch.tensor(scalars, device=device)
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output_tensors, tensorized_scalar)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
@ -596,7 +596,7 @@ def is_torch_bf16_available_on_device(device):
|
||||
return True
|
||||
|
||||
try:
|
||||
x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
|
||||
x = torch.zeros(2, 2, dtype=torch.bfloat16, device=device)
|
||||
_ = x @ x
|
||||
except: # noqa: E722
|
||||
# TODO: more precise exception matching, if possible.
|
||||
|
Loading…
Reference in New Issue
Block a user