[CSM] infer codec model with no_grad + audio eos label (#38215)

* infer codec model with no_grad

* codec_model eval

* training labels: add audio eos token
Author: eustlb
Committed: 2025-05-27 16:10:17 +02:00 (by GitHub)
Parent: 10ae443ec0
Commit: 3142bd8592
4 changed files with 40 additions and 31 deletions
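The first two changes apply a standard PyTorch pattern for a frozen sub-module used during training: keep it in eval mode and run its forward pass under torch.no_grad(), so it contributes neither gradients nor autograd memory. A minimal, self-contained sketch of the pattern (the two nn.Linear modules are illustrative stand-ins, not the CSM classes):

import torch
from torch import nn

codec = nn.Linear(16, 8)    # stand-in for the frozen codec model
backbone = nn.Linear(8, 4)  # stand-in for the trainable model

backbone.train()
codec.eval()  # dropout/normalization layers behave as at inference

x = torch.randn(2, 16)
with torch.no_grad():     # no autograd graph is built for the codec
    codes = codec(x)
logits = backbone(codes)  # gradients flow only through the backbone
logits.sum().backward()
assert codec.weight.grad is None and backbone.weight.grad is not None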


@@ -315,6 +315,7 @@ device = "cuda"
 processor = AutoProcessor.from_pretrained(model_id)
 model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
 model.train()
+model.codec_model.eval()

 ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
 # ensure the audio is 24kHz
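Why the extra line is needed: nn.Module.train() is recursive, so model.train() also switches the codec into training mode; calling eval() on the sub-module afterwards restores inference behavior (dropout disabled, norm layers using running statistics) for the codec only. A quick check, reusing the names from the snippet above:

model.train()
model.codec_model.eval()
# train() flips every sub-module; the later eval() overrides it locally
assert model.training and not model.codec_model.training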


@@ -981,22 +981,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
         # =======================================
         # TODO: @eustlb, this should be batched !!!
         # but requires making sure batched inference of the codec model works as intended
-        audio_tokens_list = []
-        for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
-            batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
-            for i in range(batch_input_values_cutoffs.shape[0] - 1):
-                start_idx = batch_input_values_cutoffs[i]
-                end_idx = batch_input_values_cutoffs[i + 1]
-                audio_batch = batch_input_values[..., start_idx:end_idx]
-                codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
-                codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
-                audio_tokens_list.append(codebook_ids[0])
+        with torch.no_grad():
+            audio_tokens_list = []
+            for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+                batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+                for i in range(batch_input_values_cutoffs.shape[0] - 1):
+                    start_idx = batch_input_values_cutoffs[i]
+                    end_idx = batch_input_values_cutoffs[i + 1]
+                    audio_batch = batch_input_values[..., start_idx:end_idx]
+                    codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+                    codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+                    audio_tokens_list.append(codebook_ids[0])

-        max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
-        batched_audio_token_ids = torch.stack(
-            [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
-        )
-        audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+            max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+            batched_audio_token_ids = torch.stack(
+                [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+            )
+            audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
         # =======================================
         audio_token_id = self.config.audio_token_id
         audio_token_mask = input_ids == audio_token_id
@@ -1018,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
         if labels is not None:
             labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
             labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+            labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
             # mask depth decoder
             depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
             labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
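To make the encode-then-pad step concrete: each audio segment encodes to a variable number of frames, and the (0, 0, 0, pad) argument to nn.functional.pad right-pads the frame axis of each (frames, num_codebooks) tensor before stacking. A runnable sketch with a dummy encode standing in for self.codec_model.encode (the shapes and the 100-samples-per-frame rate are made-up assumptions):

import torch
from torch import nn

num_codebooks = 4

def encode(chunk: torch.Tensor) -> torch.Tensor:
    # dummy codec: pretend 100 waveform samples produce one frame
    frames = chunk.shape[-1] // 100
    return torch.randint(0, 1024, (frames, num_codebooks))

chunks = [torch.randn(300), torch.randn(500)]  # variable-length segments
with torch.no_grad():
    audio_tokens_list = [encode(c) for c in chunks]

# right-pad the frame axis of each (frames, num_codebooks) tensor, then batch
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
batched_audio_token_ids = torch.stack(
    [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
)
print(batched_audio_token_ids.shape)  # torch.Size([2, 5, 4])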


@@ -595,22 +595,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
         # =======================================
         # TODO: @eustlb, this should be batched !!!
         # but requires making sure batched inference of the codec model works as intended
-        audio_tokens_list = []
-        for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
-            batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
-            for i in range(batch_input_values_cutoffs.shape[0] - 1):
-                start_idx = batch_input_values_cutoffs[i]
-                end_idx = batch_input_values_cutoffs[i + 1]
-                audio_batch = batch_input_values[..., start_idx:end_idx]
-                codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
-                codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
-                audio_tokens_list.append(codebook_ids[0])
+        with torch.no_grad():
+            audio_tokens_list = []
+            for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+                batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+                for i in range(batch_input_values_cutoffs.shape[0] - 1):
+                    start_idx = batch_input_values_cutoffs[i]
+                    end_idx = batch_input_values_cutoffs[i + 1]
+                    audio_batch = batch_input_values[..., start_idx:end_idx]
+                    codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+                    codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+                    audio_tokens_list.append(codebook_ids[0])

-        max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
-        batched_audio_token_ids = torch.stack(
-            [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
-        )
-        audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+            max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+            batched_audio_token_ids = torch.stack(
+                [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+            )
+            audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
         # =======================================
         audio_token_id = self.config.audio_token_id
         audio_token_mask = input_ids == audio_token_id
@@ -632,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
         if labels is not None:
             labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
             labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+            labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
             # mask depth decoder
             depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
             labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
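This second, identical hunk is the counterpart in the modular file that transformers keeps in sync with the generated modeling file. The added label line writes a per-codebook EOS frame wherever the text-level labels carry the audio EOS token. A toy sketch of that expansion (token ids and the EOS frame values are made up; audio_eos_frame_ids here stands in for the model's tensor of the same name):

import torch

num_codebooks = 4
labels = torch.tensor([[-100, 7, 7, 9]])  # 7 = audio token, 9 = audio EOS (made-up ids)
audio_eos_token_mask = labels == 9

# expand (batch, seq) labels to one label per codebook
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, num_codebooks)  # (1, 4, 4)

# stamp a per-codebook EOS frame at the audio-EOS positions
audio_eos_frame_ids = torch.zeros(num_codebooks, dtype=torch.long)
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids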


@@ -353,7 +353,11 @@ class CsmProcessor(ProcessorMixin):
             else:
                 skip_frames_idxs = audio_frame_idxs

-            labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
+            labels = torch.where(
+                (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
+                data["input_ids"],
+                -100,
+            )
             labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101

             data["labels"] = labels