Diffusion for text generation
Generative AI is evolving faster than anyone imagined, and LLMs are improving at a rapid pace. Over the past year, researchers have begun experimenting with architectures beyond the Transformer, such as Mamba, Google’s Titans, Meta’s Large Concept Models, and Byte Latent Transformers.
Another interesting alternative, recently introduced, is the LLDM (Large Language Diffusion Model): rather than generating text autoregressively like today’s LLMs, it is a diffusion model trained from scratch under the familiar pre-training and supervised fine-tuning (SFT) paradigm.
LLaDA applies diffusion, a technique usually associated with image generation, to text generation as well.
Before jumping ahead
What is Diffusion?
Diffusion models are a class of generative models that gradually transform data from a simple distribution (like noise) to a complex distribution (like natural language text). This transformation is achieved through a series of small, incremental steps. The process involves two main phases: a forward process and a reverse process.

Forward Process: Gradually masks or corrupts the data until it becomes pure noise.
Reverse Process: Gradually reconstructs the original data from the noise by predicting and unmasking tokens step by step (a toy code sketch of both phases follows below).
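To make the two phases concrete, here is a toy sketch in plain Python. `predict_all` is a stand-in for the trained mask predictor discussed later; everything else is just masking/unmasking bookkeeping, not LLaDA’s actual implementation.

```python
import random

MASK = "[MASK]"

def forward_process(tokens, t):
    """Forward process: corrupt the sequence by masking each token
    independently with probability t (t=0 keeps everything, t=1 masks everything)."""
    return [MASK if random.random() < t else tok for tok in tokens]

def reverse_process(tokens, predict_all, steps=4):
    """Reverse process: over several steps, predict every masked token and
    keep a growing fraction of the predictions until nothing is masked."""
    tokens = list(tokens)
    for step in range(1, steps + 1):
        predictions = predict_all(tokens)       # predict all masked positions at once
        for i, tok in enumerate(tokens):
            if tok == MASK and random.random() < step / steps:
                tokens[i] = predictions[i]      # unmask this position
    return tokens
```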
Large Language Diffusion Models (LLDMs)
- LLDMs combine the principles of diffusion models with the capabilities of large language models. They aim to generate text by learning the distribution of language through a diffusion process.
- LLaDA (Large Language Diffusion with Masking) is an example of such a model. It uses a forward data masking process and a reverse process to model distributions and generate text.
How LLaDA Works
Forward Process:
- The model starts with a sequence of tokens (words or characters) and gradually masks them independently until the sequence is fully masked. The masking probability increases linearly with time.
- For example, at time t=0, the sequence is fully observed (no tokens are masked). At time t=1, all tokens are masked.
Reverse Process:
- The reverse process aims to recover the original sequence from the fully masked sequence. It iteratively predicts and unmasks tokens, moving from t=1 back to t=0.
- A mask predictor, typically a Transformer model, is trained to predict the masked tokens based on the partially masked sequence (a minimal sketch of such a predictor is shown after this list).
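Concretely, the mask predictor is a sequence model that sees the whole (partially masked) sequence at once. Below is a minimal PyTorch sketch; the class name and small dimensions are illustrative, not LLaDA’s actual configuration. The key detail is that no causal attention mask is applied, so every position can attend to both its left and right context.

```python
import torch
import torch.nn as nn

class MaskPredictor(nn.Module):
    """Bidirectional Transformer that predicts the original id of every token.
    Unlike a causal LM, no attention mask is applied."""
    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4, max_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, token_ids):                        # (batch, seq_len)
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        h = self.tok_emb(token_ids) + self.pos_emb(positions)
        h = self.encoder(h)                              # full bidirectional attention
        return self.lm_head(h)                           # (batch, seq_len, vocab_size)
```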
Example
Imagine we have a simple sentence:
“The quick brown fox jumps over the lazy dog.”
Forward Process
Initialization:
- Start with the original sentence: “The quick brown fox jumps over the lazy dog.”
- At t=0, the sentence is fully observed (no tokens are masked).
Masking Tokens:
Gradually mask tokens independently with increasing probability as t approaches 1.
For simplicity, let’s look at two snapshots, at t=0.5 and t=1 (reproduced by the short script after this walkthrough).
At t=0.5, randomly mask some tokens:
Original: “The quick brown fox jumps over the lazy dog.”
Masked: “The [MASK] brown [MASK] jumps over the [MASK] dog.”
At t=1, all tokens are masked:
Fully masked: “[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK].”
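These snapshots can be reproduced with a few lines of Python. Splitting on spaces is a simplification of real tokenization, and the exact tokens masked at t=0.5 depend on the random seed.

```python
import random

MASK = "[MASK]"
tokens = "The quick brown fox jumps over the lazy dog.".split()

def mask_at(tokens, t, rng):
    """Mask each token independently with probability t (linear schedule)."""
    return " ".join(MASK if rng.random() < t else tok for tok in tokens)

rng = random.Random(0)
print(mask_at(tokens, 0.0, rng))   # t = 0.0: the original sentence, nothing masked
print(mask_at(tokens, 0.5, rng))   # t = 0.5: roughly half the tokens are masked
print(mask_at(tokens, 1.0, rng))   # t = 1.0: every token is [MASK]
```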
Reverse Process
Initialization:
- Start with a fully masked sequence: “[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK].”
Predicting Tokens:
- Use the trained mask predictor (a Transformer model) to predict the original tokens iteratively.
At t=1 (the starting point, fully masked):
Sequence: “[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK].”
At t=0.5 (partially unmasked):
Predicted: “The [MASK] brown [MASK] jumps over the [MASK] dog.”
The model predicts the masked tokens based on the context provided by the unmasked tokens.
At t=0 (fully unmasked):
Predicted: “The quick brown fox jumps over the lazy dog.”
The model has successfully reconstructed the original sentence.
Training
Data Preparation:
- Prepare a large corpus of text data.
- For each training example, randomly sample a masking ratio ‘t’ and mask tokens accordingly.
Loss Calculation:
- Compute the cross-entropy loss only for the masked tokens.
- The loss function ensures that the model learns to reproduce the original tokens accurately (a simplified code sketch of this objective follows below).
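A simplified PyTorch sketch of this objective is shown below. It assumes `model` maps token ids to per-position logits and `mask_id` is the id of the mask token; the 1/t reweighting follows the LLaDA paper, while the exact normalization is simplified here.

```python
import torch
import torch.nn.functional as F

def diffusion_pretraining_loss(model, batch_ids, mask_id):
    """Sketch of the masked-diffusion training objective.
    batch_ids: (batch, seq_len) token ids; model: ids -> (batch, seq_len, vocab) logits."""
    b, L = batch_ids.shape
    t = torch.rand(b, 1, device=batch_ids.device).clamp_min(1e-3)    # masking ratio t ~ U[0, 1]
    is_masked = torch.rand(b, L, device=batch_ids.device) < t        # mask each token with prob. t
    corrupted = torch.where(is_masked, torch.full_like(batch_ids, mask_id), batch_ids)

    logits = model(corrupted)
    token_nll = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), batch_ids.reshape(-1), reduction="none"
    ).view(b, L)

    # Cross-entropy is counted only at masked positions; the LLaDA paper also
    # reweights by 1/t (the normalization used here is a simplification).
    return ((token_nll * is_masked) / t).sum() / (b * L)
```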
Inference
Sampling:
- Start with a fully masked sequence.
- Use the trained model to iteratively predict and unmask tokens.
Example:
- Prompt: “The [MASK] brown [MASK] jumps over the [MASK] dog.”
- Sampling steps (a toy, runnable version follows after this list):
Predict “quick” for the first [MASK].
Predict “fox” for the second [MASK].
Predict “lazy” for the third [MASK].
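Here is a self-contained toy rendering of those steps. `toy_predictor` is a hypothetical stand-in for the trained model (it simply looks the answer up) so the snippet runs on its own, and it fills one mask at a time only to mirror the list above; the real model predicts all masked tokens in parallel at each step.

```python
MASK = "[MASK]"

def fill_masks(tokens, predict_token):
    """Iteratively predict and unmask one [MASK] at a time, left to right."""
    tokens = list(tokens)
    while MASK in tokens:
        i = tokens.index(MASK)                 # next masked position
        tokens[i] = predict_token(tokens, i)   # condition on the current partial sequence
        print(" ".join(tokens))                # show each sampling step
    return tokens

# Toy stand-in for the trained mask predictor, purely for illustration.
def toy_predictor(tokens, i):
    reference = "The quick brown fox jumps over the lazy dog.".split()
    return reference[i]

prompt = "The [MASK] brown [MASK] jumps over the [MASK] dog.".split()
fill_masks(prompt, toy_predictor)
# Predicts "quick", then "fox", then "lazy", recovering the full sentence.
```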
Main stages of LLaDA

Pre-training:
- Masking: All tokens in the input text are masked independently with a random ratio t sampled from the uniform distribution U[0,1]. This means each token has probability t of being replaced by the mask token.
- Mask Predictor: A neural network, typically a Transformer, is trained to predict the original tokens from the masked input. It learns to associate the context of unmasked tokens with the correct masked tokens.
- Remasking: After prediction, some of the predicted tokens may be re-masked, mirroring the diffusion process and letting later steps revise earlier predictions.
Supervised Fine-tuning (SFT):
- Prompt and Response: The model is fine-tuned on pairs of prompts and responses. Only the response tokens are subject to masking (sketched in code after these points).
- Mask Predictor: The same mask predictor is used to predict the original response tokens from the masked response, given the prompt as context.
- Intermediate Steps: The model may go through several intermediate steps where it refines its predictions based on the prompt and the partially masked response.
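A sketch of the SFT-style corruption, assuming `prompt_ids` and `response_ids` are 1-D tensors of token ids and `mask_id` is the mask token’s id: only response positions are ever masked, and the loss is later taken only at those masked positions.

```python
import torch

def mask_response_only(prompt_ids, response_ids, mask_id):
    """SFT corruption sketch: the prompt stays intact; each response token is
    masked independently with probability t ~ U[0, 1]."""
    t = torch.rand(1)                                            # one masking ratio per example
    keep_mask = torch.rand(response_ids.shape) >= t              # True where the token survives
    corrupted_response = torch.where(keep_mask, response_ids,
                                     torch.full_like(response_ids, mask_id))
    # The mask predictor sees [prompt | partially masked response] and is trained
    # to recover the original response tokens at the masked positions.
    model_input = torch.cat([prompt_ids, corrupted_response])
    loss_positions = torch.cat([torch.zeros_like(prompt_ids, dtype=torch.bool),
                                ~keep_mask])                     # loss only on masked response tokens
    return model_input, loss_positions
```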
Sampling:
- Diffusion Process: During sampling, LLaDA simulates a diffusion process that starts with a fully masked sequence (t=1) and gradually unmasks tokens to reach the original sequence (t=0).
- Simultaneous Prediction: At each step of the diffusion process, the model predicts all masked tokens simultaneously.
- Flexible Remasking Strategies: The model can employ different strategies to decide which tokens to remask at each step, allowing more flexible and context-aware generation; one common choice, low-confidence remasking, is sketched below.
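In low-confidence remasking, every masked token is filled at each step, then the least confident predictions are re-masked so later steps can revise them. The sketch below assumes a `model` that maps a batch of token ids to per-position logits; the function name and the linear unmasking schedule are illustrative, not LLaDA’s exact procedure.

```python
import torch

@torch.no_grad()
def sample_with_low_confidence_remasking(model, prompt_ids, gen_len, mask_id, steps=8):
    """Sketch of low-confidence remasking: predict all masks each step, keep the
    most confident predictions, and re-mask the rest for later revision."""
    device = prompt_ids.device
    x = torch.cat([prompt_ids,
                   torch.full((gen_len,), mask_id, dtype=torch.long, device=device)])
    gen_slice = slice(len(prompt_ids), len(x))
    for step in range(steps):
        logits = model(x.unsqueeze(0)).squeeze(0)            # (seq_len, vocab)
        probs = logits.softmax(dim=-1)
        conf, pred = probs.max(dim=-1)                        # confidence + argmax per position
        is_mask = x == mask_id
        x = torch.where(is_mask, pred, x)                     # tentatively fill every mask
        # decide how many generated tokens should stay unmasked after this step
        keep = int(gen_len * (step + 1) / steps)
        n_remask = gen_len - keep
        if n_remask > 0:
            # already-unmasked tokens get infinite confidence so they are never re-masked
            conf_gen = torch.where(is_mask[gen_slice], conf[gen_slice],
                                   torch.full_like(conf[gen_slice], float("inf")))
            remask_idx = conf_gen.topk(n_remask, largest=False).indices + len(prompt_ids)
            x[remask_idx] = mask_id                           # re-mask lowest-confidence predictions
    return x[gen_slice]
```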
Performance and metrics

- Strengths: LLaDA 8B performs well in mathematics, science, and Chinese tasks, often outperforming or being competitive with other models.
- Weaknesses: It shows lower performance in some general tasks and code-related benchmarks compared to other models like LLaMA3 8B and Qwen2.5 7B.
- Overall: LLaDA 8B is competitive across various benchmarks, with notable strengths in specific areas such as mathematics and Chinese tasks.
Concluding this long post
Large Language Diffusion Models (LLDMs) offer a fresh approach to text generation by adapting diffusion techniques traditionally used in image synthesis. LLaDA, a prime example, gradually masks and unmasks text through a structured diffusion process, enabling flexible and context-aware generation. While still evolving, LLDMs show promise in tasks like mathematics and multilingual applications. As research advances, diffusion-based models could complement or even challenge transformer-based LLMs in the future.