mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-16 19:18:24 +06:00
[CSM] infer codec model with no_grad + audio eos label (#38215)
* infer codec model with no_grad * codec_model eval * training labels: add audio eos token
This commit is contained in:
parent
10ae443ec0
commit
3142bd8592
@ -315,6 +315,7 @@ device = "cuda"
|
|||||||
processor = AutoProcessor.from_pretrained(model_id)
|
processor = AutoProcessor.from_pretrained(model_id)
|
||||||
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
|
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
|
||||||
model.train()
|
model.train()
|
||||||
|
model.codec_model.eval()
|
||||||
|
|
||||||
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||||
# ensure the audio is 24kHz
|
# ensure the audio is 24kHz
|
||||||
|
@ -981,6 +981,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
|||||||
# =======================================
|
# =======================================
|
||||||
# TODO: @eustlb, this should be batched !!!
|
# TODO: @eustlb, this should be batched !!!
|
||||||
# but requires making sure batched inference of the codec model works as intended
|
# but requires making sure batched inference of the codec model works as intended
|
||||||
|
with torch.no_grad():
|
||||||
audio_tokens_list = []
|
audio_tokens_list = []
|
||||||
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
||||||
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
||||||
@ -1018,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
|
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
|
||||||
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
|
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
|
||||||
|
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
|
||||||
# mask depth decoder
|
# mask depth decoder
|
||||||
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
|
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
|
||||||
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
|
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
|
||||||
|
@ -595,6 +595,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
|||||||
# =======================================
|
# =======================================
|
||||||
# TODO: @eustlb, this should be batched !!!
|
# TODO: @eustlb, this should be batched !!!
|
||||||
# but requires making sure batched inference of the codec model works as intended
|
# but requires making sure batched inference of the codec model works as intended
|
||||||
|
with torch.no_grad():
|
||||||
audio_tokens_list = []
|
audio_tokens_list = []
|
||||||
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
||||||
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
||||||
@ -632,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
|
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
|
||||||
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
|
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
|
||||||
|
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
|
||||||
# mask depth decoder
|
# mask depth decoder
|
||||||
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
|
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
|
||||||
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
|
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
|
||||||
|
@ -353,7 +353,11 @@ class CsmProcessor(ProcessorMixin):
|
|||||||
else:
|
else:
|
||||||
skip_frames_idxs = audio_frame_idxs
|
skip_frames_idxs = audio_frame_idxs
|
||||||
|
|
||||||
labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
|
labels = torch.where(
|
||||||
|
(data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
|
||||||
|
data["input_ids"],
|
||||||
|
-100,
|
||||||
|
)
|
||||||
labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
|
labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
|
||||||
|
|
||||||
data["labels"] = labels
|
data["labels"] = labels
|
||||||
|
Loading…
Reference in New Issue
Block a user