mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Whisper] patch float type on mps (#35295)
* fix float type on mps * make
This commit is contained in:
parent
d5b81e1ca1
commit
9feae5fb01
@ -632,7 +632,9 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
cur_bsz=cur_bsz,
|
||||
batch_idx_map=batch_idx_map,
|
||||
)
|
||||
time_offset = seek.to(torch.float64) * time_precision / input_stride
|
||||
time_offset = (
|
||||
seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
|
||||
)
|
||||
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
|
||||
|
||||
# 6.2 cut out next 30s segment from input features
|
||||
@ -1805,6 +1807,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
timestamp_segment_indices.add_(1)
|
||||
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
|
||||
device = seek_sequence.device
|
||||
|
||||
# If whisper predicted an "end of segment" via a timestamp token, let's go over each
|
||||
# "end of segment" prediction and slice the decoding into segments accordingly
|
||||
@ -1828,8 +1831,12 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
|
||||
segments.append(
|
||||
{
|
||||
"start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
|
||||
"end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
|
||||
"start": time_offset[prev_idx]
|
||||
+ start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
|
||||
* time_precision,
|
||||
"end": time_offset[prev_idx]
|
||||
+ end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
|
||||
* time_precision,
|
||||
"tokens": sliced_tokens,
|
||||
"result": seek_outputs[idx],
|
||||
}
|
||||
@ -1856,7 +1863,9 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
|
||||
if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
|
||||
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
|
||||
torch.float32 if device.type == "mps" else torch.float64
|
||||
)
|
||||
segments = [
|
||||
{
|
||||
"start": time_offset[prev_idx],
|
||||
|
Loading…
Reference in New Issue
Block a user