Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-18 03:58:25 +06:00)

[CSM] infer codec model with no_grad + audio eos label (#38215)

* infer codec model with no_grad
* codec_model eval
* training labels: add audio eos token

parent 10ae443ec0
commit 3142bd8592
@@ -315,6 +315,7 @@ device = "cuda"
 processor = AutoProcessor.from_pretrained(model_id)
 model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
 model.train()
+model.codec_model.eval()
 
 ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
 # ensure the audio is 24kHz
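A note on the doc change above: model.codec_model.eval() keeps the codec's dropout and normalization layers in inference mode while the rest of the model trains, and the torch.no_grad() added in the hunks below additionally skips building the autograd graph for the codec's encode pass. A minimal sketch of the combined pattern, using a stand-in module rather than the actual codec model:

import torch
from torch import nn

# Stand-in for the codec model: eval() fixes dropout/normalization behavior,
# no_grad() skips recording the autograd graph for its forward pass.
codec = nn.Sequential(nn.Conv1d(1, 4, kernel_size=3), nn.Dropout(0.1))
codec.eval()
with torch.no_grad():
    codes = codec(torch.randn(1, 1, 24000))  # dummy 1-second, 24 kHz waveform
print(codes.requires_grad)  # False: no gradients flow back into the codec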
@@ -981,22 +981,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
             # =======================================
             # TODO: @eustlb, this should be batched !!!
             # but requires making sure batched inference of the codec model works as intended
-            audio_tokens_list = []
-            for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
-                batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
-                for i in range(batch_input_values_cutoffs.shape[0] - 1):
-                    start_idx = batch_input_values_cutoffs[i]
-                    end_idx = batch_input_values_cutoffs[i + 1]
-                    audio_batch = batch_input_values[..., start_idx:end_idx]
-                    codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
-                    codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
-                    audio_tokens_list.append(codebook_ids[0])
+            with torch.no_grad():
+                audio_tokens_list = []
+                for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+                    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+                    for i in range(batch_input_values_cutoffs.shape[0] - 1):
+                        start_idx = batch_input_values_cutoffs[i]
+                        end_idx = batch_input_values_cutoffs[i + 1]
+                        audio_batch = batch_input_values[..., start_idx:end_idx]
+                        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+                        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+                        audio_tokens_list.append(codebook_ids[0])
 
             max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
             batched_audio_token_ids = torch.stack(
                 [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
             )
             audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
             # =======================================
             audio_token_id = self.config.audio_token_id
             audio_token_mask = input_ids == audio_token_id
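The tail of this hunk right-pads each per-utterance (frames, codebooks) tensor along the frame axis and stacks the results into one batch. A minimal sketch of that padding pattern with toy shapes (the 32 codebooks and the id range are illustrative, not taken from the CSM config):

import torch
from torch import nn

# Three dummy codebook-id tensors of shape (frames, codebooks) with varying frame counts.
audio_tokens_list = [torch.randint(0, 2048, (frames, 32)) for frames in (5, 9, 7)]
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
# pad spec (0, 0, 0, k): last dim (codebooks) untouched, frame axis padded at the end.
batched_audio_token_ids = torch.stack(
    [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
)
print(batched_audio_token_ids.shape)  # torch.Size([3, 9, 32])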
@@ -1018,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
         if labels is not None:
             labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
             labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+            labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
             # mask depth decoder
             depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
             labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
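The label logic above expands the (batch, seq) text labels to (batch, seq, num_codebooks), writes the audio eos frame ids at eos positions (the added line), and masks every codebook except the first at frames tagged -101 so the depth decoder ignores them. A toy-sized sketch of those masked assignments (the ids and num_codebooks=4 are illustrative):

import torch

labels = torch.tensor([[-100, 7, 7, -101]])  # 7 stands in for the audio token id
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, 4)  # (batch, seq) -> (batch, seq, 4)
ignore = (labels == -101).nonzero(as_tuple=True)
labels_expanded[ignore[0], ignore[1], 1:] = -100  # keep codebook 0, mask the depth decoder's
print(labels_expanded[0, 3])  # tensor([-101, -100, -100, -100])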
@@ -595,22 +595,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
             # =======================================
             # TODO: @eustlb, this should be batched !!!
             # but requires making sure batched inference of the codec model works as intended
-            audio_tokens_list = []
-            for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
-                batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
-                for i in range(batch_input_values_cutoffs.shape[0] - 1):
-                    start_idx = batch_input_values_cutoffs[i]
-                    end_idx = batch_input_values_cutoffs[i + 1]
-                    audio_batch = batch_input_values[..., start_idx:end_idx]
-                    codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
-                    codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
-                    audio_tokens_list.append(codebook_ids[0])
+            with torch.no_grad():
+                audio_tokens_list = []
+                for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+                    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+                    for i in range(batch_input_values_cutoffs.shape[0] - 1):
+                        start_idx = batch_input_values_cutoffs[i]
+                        end_idx = batch_input_values_cutoffs[i + 1]
+                        audio_batch = batch_input_values[..., start_idx:end_idx]
+                        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+                        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+                        audio_tokens_list.append(codebook_ids[0])
 
             max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
             batched_audio_token_ids = torch.stack(
                 [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
            )
             audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
             # =======================================
             audio_token_id = self.config.audio_token_id
             audio_token_mask = input_ids == audio_token_id
@@ -632,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
         if labels is not None:
             labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
             labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+            labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
             # mask depth decoder
             depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
             labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
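The two hunks above repeat the earlier -981/-1018 changes verbatim; given how CSM is laid out in transformers, these are presumably the modular source file and its generated modeling counterpart receiving the same edit.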
@@ -353,7 +353,11 @@ class CsmProcessor(ProcessorMixin):
             else:
                 skip_frames_idxs = audio_frame_idxs
 
-            labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
+            labels = torch.where(
+                (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
+                data["input_ids"],
+                -100,
+            )
             labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
 
             data["labels"] = labels
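The processor change widens the torch.where condition so audio eos tokens keep their ids in the labels instead of being replaced by -100. A toy sketch of the new rule (7 and 8 stand in for audio_token_id and audio_eos_token_id):

import torch

input_ids = torch.tensor([[1, 7, 7, 8, 2]])
labels = torch.where((input_ids == 7) | (input_ids == 8), input_ids, -100)
print(labels)  # tensor([[-100, 7, 7, 8, -100]])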