diff --git a/docs/source/en/model_doc/csm.md b/docs/source/en/model_doc/csm.md
index 2d916da161f..53c24a5eba5 100644
--- a/docs/source/en/model_doc/csm.md
+++ b/docs/source/en/model_doc/csm.md
@@ -315,6 +315,7 @@ device = "cuda"
 processor = AutoProcessor.from_pretrained(model_id)
 model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
 model.train()
+model.codec_model.eval()
 
 ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
 # ensure the audio is 24kHz
diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py
index e1f1d477b38..77716956544 100644
--- a/src/transformers/models/csm/modeling_csm.py
+++ b/src/transformers/models/csm/modeling_csm.py
@@ -981,22 +981,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
             # =======================================
             # TODO: @eustlb, this should be batched !!!
             # but requires making sure batched inference of the codec model works as intended
-            audio_tokens_list = []
-            for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
-                batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
-                for i in range(batch_input_values_cutoffs.shape[0] - 1):
-                    start_idx = batch_input_values_cutoffs[i]
-                    end_idx = batch_input_values_cutoffs[i + 1]
-                    audio_batch = batch_input_values[..., start_idx:end_idx]
-                    codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
-                    codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
-                    audio_tokens_list.append(codebook_ids[0])
+            with torch.no_grad():
+                audio_tokens_list = []
+                for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+                    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+                    for i in range(batch_input_values_cutoffs.shape[0] - 1):
+                        start_idx = batch_input_values_cutoffs[i]
+                        end_idx = batch_input_values_cutoffs[i + 1]
+                        audio_batch = batch_input_values[..., start_idx:end_idx]
+                        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+                        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+                        audio_tokens_list.append(codebook_ids[0])
 
-            max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
-            batched_audio_token_ids = torch.stack(
-                [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
-            )
-            audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+                max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+                batched_audio_token_ids = torch.stack(
+                    [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+                )
+                audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
             # =======================================
             audio_token_id = self.config.audio_token_id
             audio_token_mask = input_ids == audio_token_id
@@ -1018,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
             if labels is not None:
                 labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
                 labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+                labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
                 # mask depth decoder
                 depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
                 labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py
index aab2d131c45..86483076d30 100644
--- a/src/transformers/models/csm/modular_csm.py
+++ b/src/transformers/models/csm/modular_csm.py
@@ -595,22 +595,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
             # =======================================
             # TODO: @eustlb, this should be batched !!!
             # but requires making sure batched inference of the codec model works as intended
-            audio_tokens_list = []
-            for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
-                batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
-                for i in range(batch_input_values_cutoffs.shape[0] - 1):
-                    start_idx = batch_input_values_cutoffs[i]
-                    end_idx = batch_input_values_cutoffs[i + 1]
-                    audio_batch = batch_input_values[..., start_idx:end_idx]
-                    codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
-                    codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
-                    audio_tokens_list.append(codebook_ids[0])
+            with torch.no_grad():
+                audio_tokens_list = []
+                for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+                    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+                    for i in range(batch_input_values_cutoffs.shape[0] - 1):
+                        start_idx = batch_input_values_cutoffs[i]
+                        end_idx = batch_input_values_cutoffs[i + 1]
+                        audio_batch = batch_input_values[..., start_idx:end_idx]
+                        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+                        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+                        audio_tokens_list.append(codebook_ids[0])
 
-            max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
-            batched_audio_token_ids = torch.stack(
-                [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
-            )
-            audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+                max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+                batched_audio_token_ids = torch.stack(
+                    [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+                )
+                audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
             # =======================================
             audio_token_id = self.config.audio_token_id
             audio_token_mask = input_ids == audio_token_id
@@ -632,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
             if labels is not None:
                 labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
                 labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+                labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
                 # mask depth decoder
                 depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
                 labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
diff --git a/src/transformers/models/csm/processing_csm.py b/src/transformers/models/csm/processing_csm.py
index 486c5eda4c7..a0f91a1c3df 100644
--- a/src/transformers/models/csm/processing_csm.py
+++ b/src/transformers/models/csm/processing_csm.py
@@ -353,7 +353,11 @@ class CsmProcessor(ProcessorMixin):
             else:
                 skip_frames_idxs = audio_frame_idxs
 
-            labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
+            labels = torch.where(
+                (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
+                data["input_ids"],
+                -100,
+            )
             labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
 
             data["labels"] = labels
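The net effect of the patch is easiest to see end to end: `model.codec_model.eval()` keeps the frozen codec out of train-mode behavior, the `torch.no_grad()` wrapper stops gradients from flowing into codec encoding, and the processor/modeling changes make the audio EOS frames contribute to the loss instead of being masked out. Below is a minimal training-step sketch in the spirit of the csm.md snippet touched above; the `sesame/csm-1b` checkpoint id, the chat-template content-entry layout, and the `output_labels=True` flag follow the CSM docs page and should be treated as assumptions, not guaranteed API.

```python
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, CsmForConditionalGeneration

model_id = "sesame/csm-1b"  # assumption: checkpoint used in the CSM docs
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

model.train()
model.codec_model.eval()  # the docs change above: keep the frozen codec in eval mode

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=24000))  # ensure the audio is 24kHz

sample = ds[0]
conversation = [
    {
        "role": f"{sample['speaker_id']}",
        "content": [
            {"type": "text", "text": sample["text"]},
            {"type": "audio", "audio": sample["audio"]["array"]},  # assumption: content-key layout
        ],
    }
]

# output_labels=True exercises the processing_csm.py path patched above, so the
# labels now cover audio EOS tokens in addition to audio tokens.
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
    output_labels=True,
).to(device)

out = model(**inputs)  # codec encoding runs under no_grad per the modeling patch
out.loss.backward()
```

Keeping the codec in eval mode while the backbone trains matters because the codec is only used as a frozen tokenizer here; letting `model.train()` toggle its dropout or normalization statistics would make the audio token targets nondeterministic across steps.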