Fine-Tuning TTS model codes explained


How to Fine-Tune Audio LLMs from scratch?

Photo by Victor Freitas on Unsplash

Audio AI models are advancing quickly. What started off with just TTS models now covers music generation, conversational AI, and more.

My new book “Model Context Protocol: Advanced AI Agents for Beginners” is live now.

Model Context Protocol: Advanced AI Agents for Beginners (Generative AI books)

While the base versions of these audio models are commendable, there is still room for improvement. This is very similar to language models that generate text: the base versions are good but still benefit from fine-tuning.


Fine-tuning for text-generation models has been explored in considerable depth over the last year, but that is not yet the case for audio models, where fine-tuning is still a very new topic. Today, we will explore how to fine-tune a TTS model from scratch using Unsloth.

We will be fine-tuning Sesame CSM-1B, a conversational speech model with one billion parameters, on Google Colab for free.


So let’s get started.

1. We will start off by pip installing some required libraries like Unsloth, Transformers, Accelerate, etc.

%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth
    !pip install transformers==4.52.3

2. Next, we will load the pre-trained CSM-1B model from Hugging Face.

from unsloth import FastModel
from transformers import CsmForConditionalGeneration
import torch

model, processor = FastModel.from_pretrained(
    model_name = "unsloth/csm-1b",
    max_seq_length = 2048, # Choose any for long context!
    dtype = None, # Leave as None for auto-detection
    auto_model = CsmForConditionalGeneration,
    load_in_4bit = False, # Select True for 4-bit - reduces memory usage
)

Why FastModel is used: it’s an Unsloth utility that loads transformer models faster and with lower memory usage, optimized for fine-tuning and inference on limited hardware. It plays the same role as Hugging Face’s .from_pretrained() function (see the sketch after the hyperparameter list for a rough vanilla-Transformers comparison).

Hyperparameters:

  • model_name: which model to load,
  • max_seq_length: max input size the model will handle,
  • dtype: lets Unsloth auto-select best precision (like float16),
  • auto_model: model architecture to use under the hood,
  • load_in_4bit: whether to load a compressed 4-bit version to save memory.
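For reference, here is a rough vanilla-Transformers equivalent of the load above. This is only a sketch: it skips Unsloth’s speed and memory optimizations, and the explicit float16 dtype stands in for Unsloth’s automatic dtype detection.

from transformers import CsmForConditionalGeneration, AutoProcessor
import torch

# Plain Hugging Face loading of the same checkpoint, without Unsloth's patches.
hf_model = CsmForConditionalGeneration.from_pretrained(
    "unsloth/csm-1b",
    torch_dtype=torch.float16,  # stand-in for Unsloth's dtype=None auto-detection
)
hf_processor = AutoProcessor.from_pretrained("unsloth/csm-1b")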

3. Setting up the LoRA config for fine-tuning.

model = FastModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none", # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

What is LoRA: A technique that injects small trainable matrices into specific model layers, allowing you to fine-tune large models efficiently without updating all their weights.
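To make the idea concrete, here is a minimal, stand-alone sketch of a LoRA update on one linear layer. It is illustrative only: the names W, A, B, r, and alpha are generic, and the layer size is hypothetical, not an Unsloth or CSM internal.

import torch

d_out, d_in, r, alpha = 2048, 2048, 32, 32
W = torch.randn(d_out, d_in)       # frozen pre-trained weight, never updated
A = torch.randn(r, d_in) * 0.01    # trainable low-rank factor (Gaussian init)
B = torch.zeros(d_out, r)          # trainable low-rank factor (zero init, so the update starts at 0)

x = torch.randn(d_in)
# Forward pass: frozen weight plus a scaled low-rank correction.
y = W @ x + (alpha / r) * (B @ (A @ x))

# Only A and B are trained: r * (d_in + d_out) parameters instead of d_in * d_out.

With r = 32 and lora_alpha = 32 as in the config above, the scaling factor alpha / r is 1; for this hypothetical 2048 x 2048 layer, LoRA trains about 131k parameters instead of roughly 4.2 million.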


Hyperparameters:

  • r: rank of LoRA matrices (controls capacity),
  • target_modules: which layers get LoRA adapters,
  • lora_alpha: scaling factor for LoRA’s impact,
  • lora_dropout: dropout within LoRA (0 is fastest),
  • bias: whether to train bias terms,
  • use_gradient_checkpointing: saves memory during training (Unsloth version is extra efficient),
  • random_state: for reproducible training behavior,
  • use_rslora: toggle for rank-stabilized LoRA (off by default),
  • loftq_config: optional LoftQ quantization config (left as None here).

4. Preparing the dataset

#@title Dataset Prep functions
from datasets import load_dataset, Audio, Dataset
import os
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("unsloth/csm-1b")

raw_ds = load_dataset("MrDragonFox/Elise", split="train")

# Getting the speaker id is important for multi-speaker models and speaker consistency
speaker_key = "source"
if "source" not in raw_ds.column_names and "speaker_id" not in raw_ds.column_names:
    print('Unsloth: No speaker found, adding default "source" of 0 for all examples')
    new_column = ["0"] * len(raw_ds)
    raw_ds = raw_ds.add_column("source", new_column)
elif "source" not in raw_ds.column_names and "speaker_id" in raw_ds.column_names:
    speaker_key = "speaker_id"

target_sampling_rate = 24000
raw_ds = raw_ds.cast_column("audio", Audio(sampling_rate=target_sampling_rate))

def preprocess_example(example):
    conversation = [
        {
            "role": str(example[speaker_key]),
            "content": [
                {"type": "text", "text": example["text"]},
                {"type": "audio", "path": example["audio"]["array"]},
            ],
        }
    ]
    try:
        model_inputs = processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_dict=True,
            output_labels=True,
            text_kwargs = {
                "padding": "max_length", # pad to the max_length
                "max_length": 256, # this should be the max length of audio
                "pad_to_multiple_of": 8,
                "padding_side": "right",
            },
            audio_kwargs = {
                "sampling_rate": 24_000,
                "max_length": 240001, # max input_values length of the whole dataset
                "padding": "max_length",
            },
            common_kwargs = {"return_tensors": "pt"},
        )
    except Exception as e:
        print(f"Error processing example with text '{example['text'][:50]}...': {e}")
        return None

    required_keys = ["input_ids", "attention_mask", "labels", "input_values", "input_values_cutoffs"]
    processed_example = {}
    # print(model_inputs.keys())
    for key in required_keys:
        if key not in model_inputs:
            print(f"Warning: Required key '{key}' not found in processor output for example.")
            return None
        value = model_inputs[key][0]
        processed_example[key] = value

    # Final check (optional but good)
    if not all(isinstance(processed_example[key], torch.Tensor) for key in processed_example):
        print(f"Error: Not all required keys are tensors in final processed example. Keys: {list(processed_example.keys())}")
        return None

    return processed_example

processed_ds = raw_ds.map(
    preprocess_example,
    remove_columns=raw_ds.column_names,
    desc="Preprocessing dataset",
)

We’ll be working with the MrDragonFox/Elise dataset, originally built for training text-to-speech (TTS) systems. The preprocessing above expects your data to include:

  • text and audio for single-speaker setups
  • source, text, and audio if your model involves multiple speakers

If you’re using a custom dataset, you can tweak this part of the code — but the input fields need to match this structure, otherwise training will either break or silently underperform.
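If you want to see the expected layout in code, here is a minimal sketch of a compatible single-speaker dataset built from local files. The file paths, transcripts, and variable names are placeholders for illustration, not part of the Elise dataset.

from datasets import Dataset, Audio

# Hypothetical local clips and transcripts; replace with your own data.
my_ds = Dataset.from_dict({
    "text": ["Hello there!", "Fine-tuning is fun."],
    "audio": ["clips/hello.wav", "clips/fun.wav"],
    "source": ["0", "0"],  # speaker id; optional for single-speaker data
})
# Decode and resample the audio to the 24 kHz the CSM processor expects.
my_ds = my_ds.cast_column("audio", Audio(sampling_rate=24000))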

MrDragonFox/Elise · Datasets at Hugging Face

Data Preparation Task:

  • Load and normalize the dataset: Pulls the audio-text dataset (MrDragonFox/Elise), ensures all audio is resampled to 24kHz, and assigns a speaker ID if missing.
  • Define multi-modal input structure: Creates a synthetic conversation for each example combining speaker info, text, and audio into a format the model can understand.
  • Tokenize and process inputs: Uses AutoProcessor to tokenize text and process audio into tensors, applying padding and length constraints.
  • Filter and clean: Drops examples with missing fields or faulty formatting, ensuring all required tensors (input_ids, labels, input_values, etc.) are present and valid.
  • Map processed data: Applies the preprocess_example function to every item in the dataset, converting it into a ready-to-train format.
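To confirm the preprocessing produced what the trainer needs, a quick inspection like the following helps (a sketch; the key names come from the mapping function above):

# Peek at one processed example and its keys.
sample = processed_ds[0]
print(sample.keys())
# Expect: input_ids, attention_mask, labels (text side) and
# input_values, input_values_cutoffs (24 kHz audio side).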

5. Setting training parameters

from transformers import TrainingArguments, Trainer
from unsloth import is_bfloat16_supported

trainer = Trainer(
    model = model,
    train_dataset = processed_ds,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01, # Turn this on if overfitting
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

trainer_stats = trainer.train()

Hyperparameters

  • per_device_train_batch_size = 2: Number of samples processed per device in one step (kept small due to audio memory load).
  • gradient_accumulation_steps = 4: Delays optimizer step by 4 batches, simulating a batch size of 8 without using more memory.
  • max_steps = 60: Total number of training steps; stops training after this.
  • learning_rate = 2e-4: Controls how fast the model learns; higher than usual since LoRA is used.
  • fp16 / bf16: Enables mixed precision training for speed and memory savings; uses bf16 if hardware allows.
  • optim = “adamw_8bit”: Memory-efficient 8-bit optimizer suitable for fine-tuning.
  • weight_decay = 0.01: Regularization to prevent overfitting by penalizing large weights.
  • lr_scheduler_type = “linear”: Gradually lowers the learning rate over time.
  • warmup_steps = 5: Slowly increases learning rate during the first 5 steps to stabilize early training.
  • output_dir = “outputs”: Saves model checkpoints and logs here.
  • seed = 3407: Ensures consistent results across runs.
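Once training finishes, you may want to keep the LoRA adapters around before moving on to inference. A minimal sketch, assuming the standard save_pretrained interface and a hypothetical folder name:

# Save the trained LoRA adapters and the processor (folder name is just an example).
model.save_pretrained("csm-1b-lora")
processor.save_pretrained("csm-1b-lora")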

6. Using the fine-tuned TTS model

from IPython.display import Audio, display
import soundfile as sf

text = "We just finished fine tuning a text to speech model... and it's pretty good!"
speaker_id = 0
inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda")
audio_values = model.generate(
    **inputs,
    max_new_tokens=125,
    output_audio=True,
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_without_context.wav", audio, 24000)
display(Audio(audio, rate=24000))

This code takes a text prompt and generates synthetic speech using a fine-tuned text-to-speech model. It encodes the input, generates audio on the GPU, saves it as a .wav file, and plays it inline using IPython’s audio player.
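If the clip cuts off too early, max_new_tokens is the main knob: CSM’s audio codec produces roughly 12.5 frames per second, so the 125 tokens above correspond to about 10 seconds of speech. A quick sketch for a longer clip, reusing the inputs from above:

# Roughly double the audio length of the example above.
audio_values = model.generate(
    **inputs,
    max_new_tokens=250,  # ~20 seconds at roughly 12.5 tokens per second
    output_audio=True,
)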

If you wish to fine-tune any other TTS model, you can find the full code in Unsloth’s documentation.

Text-to-Speech (TTS) Fine-tuning | Unsloth Documentation

Hope you try out fine-tuning your TTS model and make it sound like you.


