mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[CSM] infer codec model with no_grad + audio eos label (#38215)
* infer codec model with no_grad * codec_model eval * training labels: add audio eos token
This commit is contained in:
parent
10ae443ec0
commit
3142bd8592
@ -315,6 +315,7 @@ device = "cuda"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
|
||||
model.train()
|
||||
model.codec_model.eval()
|
||||
|
||||
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||
# ensure the audio is 24kHz
|
||||
|
@ -981,22 +981,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
# =======================================
|
||||
# TODO: @eustlb, this should be batched !!!
|
||||
# but requires making sure batched inference of the codec model works as intended
|
||||
audio_tokens_list = []
|
||||
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
||||
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
||||
for i in range(batch_input_values_cutoffs.shape[0] - 1):
|
||||
start_idx = batch_input_values_cutoffs[i]
|
||||
end_idx = batch_input_values_cutoffs[i + 1]
|
||||
audio_batch = batch_input_values[..., start_idx:end_idx]
|
||||
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
|
||||
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
|
||||
audio_tokens_list.append(codebook_ids[0])
|
||||
with torch.no_grad():
|
||||
audio_tokens_list = []
|
||||
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
||||
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
||||
for i in range(batch_input_values_cutoffs.shape[0] - 1):
|
||||
start_idx = batch_input_values_cutoffs[i]
|
||||
end_idx = batch_input_values_cutoffs[i + 1]
|
||||
audio_batch = batch_input_values[..., start_idx:end_idx]
|
||||
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
|
||||
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
|
||||
audio_tokens_list.append(codebook_ids[0])
|
||||
|
||||
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
|
||||
batched_audio_token_ids = torch.stack(
|
||||
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
|
||||
)
|
||||
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
|
||||
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
|
||||
batched_audio_token_ids = torch.stack(
|
||||
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
|
||||
)
|
||||
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
|
||||
# =======================================
|
||||
audio_token_id = self.config.audio_token_id
|
||||
audio_token_mask = input_ids == audio_token_id
|
||||
@ -1018,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
if labels is not None:
|
||||
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
|
||||
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
|
||||
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
|
||||
# mask depth decoder
|
||||
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
|
||||
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
|
||||
|
@ -595,22 +595,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
# =======================================
|
||||
# TODO: @eustlb, this should be batched !!!
|
||||
# but requires making sure batched inference of the codec model works as intended
|
||||
audio_tokens_list = []
|
||||
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
||||
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
||||
for i in range(batch_input_values_cutoffs.shape[0] - 1):
|
||||
start_idx = batch_input_values_cutoffs[i]
|
||||
end_idx = batch_input_values_cutoffs[i + 1]
|
||||
audio_batch = batch_input_values[..., start_idx:end_idx]
|
||||
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
|
||||
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
|
||||
audio_tokens_list.append(codebook_ids[0])
|
||||
with torch.no_grad():
|
||||
audio_tokens_list = []
|
||||
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
|
||||
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
|
||||
for i in range(batch_input_values_cutoffs.shape[0] - 1):
|
||||
start_idx = batch_input_values_cutoffs[i]
|
||||
end_idx = batch_input_values_cutoffs[i + 1]
|
||||
audio_batch = batch_input_values[..., start_idx:end_idx]
|
||||
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
|
||||
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
|
||||
audio_tokens_list.append(codebook_ids[0])
|
||||
|
||||
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
|
||||
batched_audio_token_ids = torch.stack(
|
||||
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
|
||||
)
|
||||
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
|
||||
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
|
||||
batched_audio_token_ids = torch.stack(
|
||||
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
|
||||
)
|
||||
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
|
||||
# =======================================
|
||||
audio_token_id = self.config.audio_token_id
|
||||
audio_token_mask = input_ids == audio_token_id
|
||||
@ -632,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
if labels is not None:
|
||||
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
|
||||
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
|
||||
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
|
||||
# mask depth decoder
|
||||
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
|
||||
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
|
||||
|
@ -353,7 +353,11 @@ class CsmProcessor(ProcessorMixin):
|
||||
else:
|
||||
skip_frames_idxs = audio_frame_idxs
|
||||
|
||||
labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
|
||||
labels = torch.where(
|
||||
(data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
|
||||
data["input_ids"],
|
||||
-100,
|
||||
)
|
||||
labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
|
||||
|
||||
data["labels"] = labels
|
||||
|
Loading…
Reference in New Issue
Block a user