[CSM] infer codec model with no_grad + audio eos label (#38215)

* infer codec model with no_grad

* codec_model eval

* training labels: add audio eos token
This commit is contained in:
eustlb 2025-05-27 16:10:17 +02:00 committed by GitHub
parent 10ae443ec0
commit 3142bd8592
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 31 deletions

View File

@@ -315,6 +315,7 @@ device = "cuda"
 processor = AutoProcessor.from_pretrained(model_id)
 model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
 model.train()
+model.codec_model.eval()
 ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
 # ensure the audio is 24kHz

View File

@@ -981,6 +981,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
 # =======================================
 # TODO: @eustlb, this should be batched !!!
 # but requires making sure batched inference of the codec model works as intended
+with torch.no_grad():
 audio_tokens_list = []
 for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
 batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
@@ -1018,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
 if labels is not None:
 labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
 labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
 # mask depth decoder
 depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
 labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100

View File

@@ -595,6 +595,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
 # =======================================
 # TODO: @eustlb, this should be batched !!!
 # but requires making sure batched inference of the codec model works as intended
+with torch.no_grad():
 audio_tokens_list = []
 for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
 batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
@@ -632,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
 if labels is not None:
 labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
 labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
 # mask depth decoder
 depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
 labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100

View File

@@ -353,7 +353,11 @@ class CsmProcessor(ProcessorMixin):
 else:
 skip_frames_idxs = audio_frame_idxs
-labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
+labels = torch.where(
+    (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
+    data["input_ids"],
+    -100,
+)
 labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
 data["labels"] = labels