
Csm
Overview
The Conversational Speech Model (CSM), released by Sesame, is the first open-source contextual text-to-speech model. It generates natural-sounding speech with or without conversational context. This context is typically a multi-turn dialogue between speakers, represented as text paired with the corresponding spoken audio.
Model Architecture: CSM is composed of two LLaMA-style autoregressive transformer decoders:
- a backbone decoder that predicts the first codebook token of each audio frame;
- a depth decoder that generates the frame's remaining codebook tokens.

It uses the pretrained codec model Mimi, introduced by Kyutai, to encode speech into discrete codebook tokens and to decode generated tokens back into audio.
The original csm-1b checkpoint is available under the Sesame organization on Hugging Face.
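The following toy sketch makes the backbone/depth-decoder split concrete by showing how one audio frame could be produced. The backbone predicts codebook 0 from all previous frames. The depth decoder then fills in that frame's remaining codebooks one at a time.

The names `generate_frame`, `generate_frames`, `dummy_backbone`, `dummy_depth` and `NUM_CODEBOOKS` are hypothetical and for illustration only; they are not Transformers APIs. In the actual model, the depth decoder also receives information from the backbone, which this toy leaves out.

```python
from typing import Callable, List

Frame = List[int]

NUM_CODEBOOKS = 4  # illustrative value for this toy only


def generate_frame(
    history: List[Frame],
    backbone_step: Callable[[List[Frame]], int],
    depth_step: Callable[[Frame], int],
    num_codebooks: int = NUM_CODEBOOKS,
) -> Frame:
    """Produce all codebook tokens for the next audio frame."""
    # Stage 1: the backbone reads every previous frame and predicts codebook 0.
    frame = [backbone_step(history)]
    # Stage 2: the depth decoder fills codebooks 1..num_codebooks-1 one at a time,
    # each step conditioned on the tokens already produced for this frame.
    while len(frame) < num_codebooks:
        frame.append(depth_step(frame))
    return frame


def generate_frames(num_frames, backbone_step, depth_step, num_codebooks=NUM_CODEBOOKS):
    history: List[Frame] = []
    for _ in range(num_frames):
        history.append(generate_frame(history, backbone_step, depth_step, num_codebooks))
    # (num_frames, num_codebooks) grid of codes; in CSM, Mimi decodes such codes into audio
    return history


# Hypothetical stand-ins so the sketch runs without any model weights.
def dummy_backbone(history: List[Frame]) -> int:
    return len(history)  # codebook 0 = frame index


def dummy_depth(frame: Frame) -> int:
    return frame[-1] + 10  # next codebook = previous token + 10


frames = generate_frames(3, dummy_backbone, dummy_depth)
print(frames)  # [[0, 10, 20, 30], [1, 11, 21, 31], [2, 12, 22, 32]]
```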

Usage Tips
Without Conversational Context
CSM can generate speech from a text prompt alone:
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
text = "[0]The past is just a story we tell ourselves."  # `[0]` for speaker id 0
inputs = processor(text, add_special_tokens=True).to(device)

# another equivalent way to prepare the inputs
conversation = [
    {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
]
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# generate the audio
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_without_context.wav")
```
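The raw-string and chat-template inputs above carry the same information: the `[0]` prefix plays the role of `"role": "0"`.

The helper below converts `[speaker_id]`-prefixed strings into that conversation format, which can be handy when prompts come from plain text. `prefixed_text_to_conversation` is a hypothetical helper, not part of the processor API.

```python
import re
from typing import Dict, List

# Matches a leading speaker tag such as "[0]" and captures the id and the text.
_SPEAKER_PREFIX = re.compile(r"^\[(\d+)\](.*)$", re.DOTALL)


def prefixed_text_to_conversation(prompts: List[str]) -> List[Dict]:
    """Hypothetical helper: turn "[id]text" strings into chat-template turns."""
    conversation = []
    for prompt in prompts:
        match = _SPEAKER_PREFIX.match(prompt)
        if match is None:
            raise ValueError(f"missing [speaker_id] prefix: {prompt!r}")
        speaker_id, text = match.groups()
        conversation.append({"role": speaker_id, "content": [{"type": "text", "text": text}]})
    return conversation


conversation = prefixed_text_to_conversation(["[0]The past is just a story we tell ourselves."])
```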
With Conversational Context
Given a conversation, CSM can generate speech that keeps each speaker's voice consistent and takes the content of previous turns into account:
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

conversation = []

# 1. context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
    conversation.append(
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

# 2. text prompt
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# generate the audio
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_with_context.wav")
```
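The loop above follows a simple pattern: every context turn pairs text with its audio, and the final turn carries only the text to be spoken. A helper capturing that pattern might look like the sketch below.

`build_context_conversation` is a hypothetical, illustrative name. Plain Python lists stand in for the 24 kHz audio arrays.

```python
from typing import Dict, List, Sequence


def build_context_conversation(
    texts: Sequence[str],
    audios: Sequence[object],
    speaker_ids: Sequence[int],
    prompt_text: str,
    prompt_speaker_id: int,
) -> List[Dict]:
    """Hypothetical helper: context turns carry text + audio, the last turn only text."""
    if not (len(texts) == len(audios) == len(speaker_ids)):
        raise ValueError("texts, audios and speaker_ids must have the same length")
    conversation = [
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio}],
        }
        for text, audio, speaker_id in zip(texts, audios, speaker_ids)
    ]
    # The final turn has no audio: it is the utterance the model should speak.
    conversation.append(
        {"role": f"{prompt_speaker_id}", "content": [{"type": "text", "text": prompt_text}]}
    )
    return conversation


example = build_context_conversation(
    texts=["hi", "hello"],
    audios=[[0.0, 0.1], [0.2]],  # stand-ins for audio arrays
    speaker_ids=[0, 1],
    prompt_text="bye",
    prompt_speaker_id=0,
)
```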
Batched Inference
CSM supports batched inference!
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# a batch of two conversations
conversation = [
    [
        {
            "role": f"{ds[0]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[0]["text"]},
                {"type": "audio", "path": ds[0]["audio"]["array"]},
            ],
        },
        {
            "role": f"{ds[1]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[1]["text"]},
            ],
        },
    ],
    [
        {
            "role": f"{ds[0]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[0]["text"]},
            ],
        }
    ],
]
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

audio = model.generate(**inputs, output_audio=True)
# one output file per batch element
processor.save_audio(audio, [f"speech_batch_idx_{i}.wav" for i in range(len(audio))])
```
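The two conversations in this batch have different lengths: the first has two turns, one of them with audio, and the second has a single text turn. Their token sequences therefore have to be padded to a common length before they can be stacked into tensors, with an attention mask marking the real positions.

The generic sketch below shows that idea on plain lists. The padding side and pad id used by the CSM processor are assumptions here, not taken from its implementation. `pad_batch` is a hypothetical helper.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_id: int, side: str = "left"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token sequences to one length; mask is 1 for real tokens, 0 for padding."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        if side == "left":
            # left padding keeps each sequence's last real token at the end,
            # which is where autoregressive generation continues from
            padded.append([pad_id] * n_pad + seq)
            mask.append([0] * n_pad + [1] * len(seq))
        else:
            padded.append(seq + [pad_id] * n_pad)
            mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask


ids, mask = pad_batch([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```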
Making The Model Go Brrr
CSM supports full-graph compilation with CUDA graphs!
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda"

# set logs to ensure no recompilation and graph breaks
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# using a static cache automatically enables torch.compile (fullgraph, "reduce-overhead" mode)
model.generation_config.max_length = 250  # big enough to avoid recompilation
model.generation_config.max_new_tokens = None  # would take precedence over max_length
model.generation_config.cache_implementation = "static"
model.depth_decoder.generation_config.cache_implementation = "static"

# generation kwargs
gen_kwargs = {
    "do_sample": False,
    "depth_decoder_do_sample": False,
    "temperature": 1.0,
    "depth_decoder_temperature": 1.0,
}

# context manager that times a block of GPU work
class TimerContext:
    def __init__(self, name="Execution"):
        self.name = name
        self.start_event = None
        self.end_event = None

    def __enter__(self):
        # Use CUDA events for more accurate GPU timing
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *args):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
        print(f"{self.name} time: {elapsed_time:.4f} seconds")

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

conversation = [
    {
        "role": f"{ds[0]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[0]["text"]},
            {"type": "audio", "path": ds[0]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[1]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[1]["text"]},
            {"type": "audio", "path": ds[1]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[2]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[2]["text"]},
        ],
    },
]

padded_inputs_1 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("First generation - compiling and recording CUDA graphs...")
with TimerContext("First generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

print("\n" + "="*50)
print("Second generation - fast !!!")
with TimerContext("Second generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

# now with different inputs
conversation = [
    {
        "role": f"{ds[2]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[2]["text"]},
            {"type": "audio", "path": ds[2]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[3]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[3]["text"]},
            {"type": "audio", "path": ds[3]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[4]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[4]["text"]},
        ],
    },
]

padded_inputs_2 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("Generation with other inputs!")
with TimerContext("Generation with different inputs"):
    _ = model.generate(**padded_inputs_2, **gen_kwargs)
print("="*50)
```
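The example fixes `max_length` and uses a static cache because a compiled graph, and especially a recorded CUDA graph, is generally reusable only while tensor shapes stay the same. A static cache allocates its storage once, sized by `max_length`, and writes each new decoding step in place. The cache shape is therefore identical at every step.

The toy class below illustrates that idea with a plain Python list. `ToyStaticCache` is a hypothetical name, and this is not how the Transformers `StaticCache` is implemented.

```python
from typing import List, Optional, Tuple


class ToyStaticCache:
    """Conceptual toy: a buffer allocated once at `max_length` and updated in place.

    Hypothetical illustration only; not the Transformers StaticCache implementation.
    """

    def __init__(self, max_length: int):
        self.max_length = max_length
        # storage is allocated up front, so its size never changes afterwards
        self.slots: List[Optional[float]] = [None] * max_length

    def update(self, cache_position: int, value: float) -> None:
        # generation must fit inside the preallocated buffer
        if not 0 <= cache_position < self.max_length:
            raise IndexError(f"cache_position {cache_position} is outside max_length={self.max_length}")
        self.slots[cache_position] = value  # in-place write, no reallocation

    @property
    def shape(self) -> Tuple[int]:
        return (len(self.slots),)


cache = ToyStaticCache(max_length=250)
shapes = set()
for position in range(10):  # ten decoding steps
    cache.update(position, float(position))
    shapes.add(cache.shape)
print(shapes)  # {(250,)}
```

In this toy, a write past the buffer raises an error instead of growing the storage. The real cache also has a fixed capacity, which is why the example above sets `max_length` to a value large enough for the whole generation.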
Training
The Transformers integration of CSM supports training!
```python
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

conversation = []

# context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
    conversation.append(
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
    output_labels=True,  # also return the labels needed to compute the loss
).to(device)

out = model(**inputs)
out.loss.backward()
```
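With `output_labels=True`, the processor also returns labels, so `model(**inputs)` can compute `out.loss`. Causal language-model losses in Transformers conventionally skip positions labeled `-100`. The sketch below illustrates that masking idea in plain Python.

Which positions the CSM processor masks, and how the backbone and depth-decoder losses are combined, are not covered here. Treat this as a conceptual toy, not CSM's actual loss. `masked_cross_entropy` is a hypothetical helper, and `IGNORE_INDEX = -100` is the assumed convention.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # conventional PyTorch / Transformers ignore index (assumed here)


def masked_cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # this position does not contribute to the loss
        # log-sum-exp computed stably by subtracting the row maximum
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]
        count += 1
    if count == 0:
        raise ValueError("all labels are ignored")
    return total / count


loss = masked_cross_entropy(logits=[[0.0, 0.0], [1.0, 2.0]], labels=[IGNORE_INDEX, 1])
print(round(loss, 4))  # 0.3133
```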
This model was contributed by Eustache Le Bihan. The original code can be found here.
CsmConfig

[[autodoc]] CsmConfig

CsmDepthDecoderConfig

[[autodoc]] CsmDepthDecoderConfig

CsmProcessor

[[autodoc]] CsmProcessor
    - __call__

CsmForConditionalGeneration

[[autodoc]] CsmForConditionalGeneration
    - forward
    - generate

CsmDepthDecoderForCausalLM

[[autodoc]] CsmDepthDecoderForCausalLM

CsmDepthDecoderModel

[[autodoc]] CsmDepthDecoderModel

CsmBackboneModel

[[autodoc]] CsmBackboneModel