This article explains the Stable Diffusion [7] model and some of its implementation details. Note: The Stable Diffusion model consists of several blocks carefully engineered together in a large code base. For simplicity, this article skips over non-essential lines of code or lightly modifies the code without altering its logic.

Architecture Details

Figure 1. The Stable Diffusion model consists of multiple stages across several blocks as shown here. The figure shows one possible data flow path labelled with Roman numerals $ I - IV $ in that order; this sequence also mirrors how the model is trained. The latent representation of an encoded image is gradually randomized by injecting noise in a Markov chain. This is followed by predicting the noise (and the denoised latent representation) conditioned on some information about what has to be recovered – in this case, a text description of the original image. Finally, the predicted denoised latent representation can be decoded to reconstruct the original image.

As shown in Figure 1, there are several pieces to the Stable Diffusion model. In the following sections, we will look into each block in more detail alongside the code snippets that make these blocks work.

Autoencoder

The autoencoder used in Stable Diffusion is similar to the one in the VQGAN paper [1]. It performs perceptual compression and is trained using a perceptual loss and a patch-based adversarial objective. The authors argue that, together, these objectives are responsible for learning semantic variation and instilling realism in the generations. Additionally, there are two options for a regularization objective that ensures the learned latent distribution is zero-centered and has low variance. All objectives used for training the autoencoder are explained in more detail in the Autoencoder subsection under Training Details below.

If the regularization objective is the KL divergence between the predicted distribution and $\mathcal{N}(0, 1)$, we are training AutoencoderKL, which we will now look at in detail. AutoencoderKL is defined as follows:

'''Source: ldm/models/autoencoder.py'''
class AutoencoderKL(pl.LightningModule):
    def __init__(self, modelconfig, lossconfig, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.encoder = Encoder(**modelconfig)
        self.decoder = Decoder(**modelconfig)
        self.loss = LPIPSWithDiscriminator(lossconfig)
        # 1x1 convolutions: map the encoder's output to the 2*embed_dim distribution
        # parameters (moments), and map latent samples back to the decoder's input channels
        self.quant_conv = torch.nn.Conv2d(2*modelconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, modelconfig["z_channels"], 1)
        # skipped irrelevant lines

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        z = posterior.sample()
        dec = self.decode(z)
        return dec, posterior

The encoding phase of the autoencoder outputs a posterior over the latent space given an input image x. The posterior is assumed to have a diagonal covariance matrix, so it can be identified using 2*embed_dim parameters per latent position (the means and log-variances of all dimensions). Then, a sample z from the posterior is decoded to obtain a prediction of the original image in the same space as x. This model is optimized using LPIPSWithDiscriminator, which includes a perceptual loss, an adversarial loss and a regularizing KL divergence loss; we shall return to this later.
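
To make the moments above concrete, here is a minimal sketch (assuming the interface used in the snippet above; the repository's DiagonalGaussianDistribution offers more functionality) of how such a diagonal Gaussian posterior can be represented and sampled with the reparameterization trick:

# Minimal sketch of a diagonal Gaussian posterior over the latent space.
import torch

class DiagonalGaussianDistribution:
    def __init__(self, moments):
        # moments: (B, 2*embed_dim, H, W) -> split into means and log-variances
        self.mean, self.logvar = torch.chunk(moments, 2, dim=1)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self):
        # reparameterization trick: z = mean + std * eps, with eps ~ N(0, I)
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self):
        # KL divergence to N(0, I), summed over the latent dimensions
        return 0.5 * torch.sum(
            self.mean ** 2 + self.logvar.exp() - 1.0 - self.logvar,
            dim=[1, 2, 3],
        )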

Further, we note here that the same z undergoes the diffusion process to obtain a noisy z so that we can train a denoising diffusion probabilistic model to estimate the denoised version given some conditioning.

The Encoder and Decoder defined above are built from well-known deep-learning modules such as residual blocks [2] and self-attention [10], arranged in a straightforward architecture. We therefore skip a detailed discussion of their implementations, which can be found in ldm/modules/diffusionmodules/model.py.
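
For readers unfamiliar with these building blocks, the following is a minimal, illustrative residual block (norm → activation → convolution with a skip connection); the blocks in the repository additionally handle channel changes, dropout and timestep embeddings.

# Minimal residual block sketch: GroupNorm -> SiLU -> Conv, plus a skip connection.
# channels is assumed to be divisible by the number of groups (32 here).
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)    # residual / skip connection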

Diffusion Model

The diffusion model in Stable Diffusion has a U-Net architecture [8] with support for processing timestep embeddings, and cross-attention between the context embedding and the latent representation to be denoised.

'''Source: ldm/modules/diffusionmodules/openaimodel.py'''

class UNetModel(torch.nn.Module):
    def __init__(self, in_channels, model_channels, out_channels,
                 num_res_blocks, channel_mult):
        super().__init__()

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = [torch.nn.Conv2d(in_channels, model_channels, 3, padding=1)]
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                self.input_blocks += [
                    ResBlock(),                 # processes z and timestep embed
                    AttentionBlock(),           # cross-attn between z and context embed
                ] # skipped arguments to modules
            self.input_blocks += [ResBlock()]   # processes z and timestep embed
        self.input_blocks = torch.nn.Sequential(*self.input_blocks)


        self.middle_block = torch.nn.Sequential(*[
            ResBlock(),                         # processes z and timestep embed
            AttentionBlock(),                   # cross-attn between z and context embed
            ResBlock()                          # processes z and timestep embed
        ]) # skipped arguments to modules

        self.output_blocks = []
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                self.output_blocks += [
                    ResBlock(),                 # processes z, skip conn from inp blocks, timestep embed
                    AttentionBlock()            # cross-attn between z and context embed
                ] # skipped arguments to modules
            self.output_blocks += [ResBlock()]  # processes z and timestep embed
        self.output_blocks = torch.nn.Sequential(*self.output_blocks)

        self.out = nn.Sequential(
            normalization(model_channels),
            nn.SiLU(),
            torch.nn.Conv2d(model_channels, out_channels, 3, padding=1),
        )

Similar to the autoencoder, the U-Net model comprises some well-known techniques. First, an MLP head projects timestep information into a high-dimensional embedding space. The input, middle and output blocks are composed of residual and cross-attention blocks, which enable the generative model to learn a conditional distribution for the reverse diffusion process. Finally, the output blocks also receive skip connections from the input blocks, a feature typical of U-Net style architectures.
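
To make the conditioning mechanism concrete, here is a minimal single-head cross-attention sketch; the repository's implementation is multi-headed and adds further details, so treat this only as an illustration of how the latent tokens attend to the context tokens.

# Minimal single-head cross-attention sketch: queries come from the latent
# features z, keys/values come from the conditioning (e.g. text) embedding.
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim, inner_dim=320):
        super().__init__()
        self.scale = inner_dim ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, z, context):
        # z: (B, N, query_dim) flattened latent tokens
        # context: (B, M, context_dim) conditioning tokens
        q, k, v = self.to_q(z), self.to_k(context), self.to_v(context)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return self.to_out(attn @ v)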

Conditioning Model

The Stable Diffusion models are parameterized to perform both conditional and unconditional image generation. For conditional image generation, we need a mechanism to encode human-interpretable conditioning, such as text or an image, into embeddings that can be fed into the cross-attention layers of the U-Net above. The authors therefore designed domain-specific conditioning models for the various data domains that they train their models on:

  • Text-to-image generation: At the time of writing, Stable Diffusion’s text-to-image generation uses a frozen CLIP [6] text encoder to obtain conditioning embeddings.

  • Class-conditional generation: To condition on class information, for example, ImageNet classes, they use a single embedding layer that maps a class index to its corresponding representation; a sketch of this conditioning interface is given after this list.
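
Concretely, the class-conditional path can be sketched as follows (the names below are illustrative, not taken from the repository); the text path works analogously, with the embedding layer replaced by a frozen CLIP text encoder that outputs a sequence of token embeddings.

# Illustrative sketch: map a class index to a one-token conditioning sequence
# that the cross-attention layers can consume as context.
import torch
import torch.nn as nn

class ClassEmbedder(nn.Module):
    def __init__(self, num_classes, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embed_dim)

    def forward(self, class_idx):
        # class_idx: (B,) integer labels -> (B, 1, embed_dim) context tokens
        return self.embedding(class_idx).unsqueeze(1)

# Example usage: condition on an ImageNet class index
embedder = ClassEmbedder(num_classes=1000, embed_dim=512)
context = embedder(torch.tensor([281]))   # shape (1, 1, 512)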

Training Details

The Stable Diffusion model is trained in two stages: (1) training the autoencoder alone, i.e., $I, IV$ only in figure 1, and (2) training the diffusion model alone after fixing the autoencoder, i.e., $I - IV$ in figure 1 but keeping $I, IV$ frozen. Let’s look at each phase in more detail.

Autoencoder

The autoencoder is trained using the LPIPSWithDiscriminator. This loss has the following components:

  • Reconstruction Loss ($\mathcal{L}_{r}$): The mean squared error between the original and reconstructed images in pixel space

  • Perceptual Loss ($\mathcal{L}_{p}$): The Learned Perceptual Image Patch Similarity (LPIPS) metric loss. In a nutshell, this objective tries to minimize the perceptual differences between the reconstructed and original images by making sure that the VGG features of both the images are close to each other. This is based on the empirical finding that differences in the representation spaces of computer vision models such as VGG capture perceptual distance very well compared to differences in image space or using metrics like PSNR [12].

  • KL divergence Loss ($\mathcal{L}_{k}$): A regularization term added to the autoencoder objective to ensure that the learned latent distribution is zero-centered and has low variance. To this end, the KL divergence is measured with respect to $\mathcal{N}(0, 1)$.

  • Adversarial Loss ($\mathcal{L}_{a}$): The autoencoder is also trained against a PatchGAN [4] discriminator, i.e., the autoencoder is trained to maximize the scores of a discriminator that looks at various patches of an image and makes a collective decision about whether the image is real or fake.

The adversarial loss is scaled with an adaptive weight such that there is a balance between the adversarial and reconstruction objectives [1]:

$$ \lambda_{adv} = \frac{\lVert \nabla (\mathcal{L}_r + \mathcal{L}_p) \rVert}{\lVert \nabla \mathcal{L}_a \rVert + \delta} $$

Intuitively, if the gradients of $\mathcal{L}_a$ are small, the adversarial term is scaled up so that the two objectives stay balanced. The gradients here are taken with respect to the parameters of the decoder’s last layer, and $\delta$ is a small constant for numerical stability.
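
The following is a minimal sketch of how such an adaptive weight can be computed, assuming $\mathcal{L}_r + \mathcal{L}_p$ and $\mathcal{L}_a$ are available as scalar tensors and last_layer refers to the weights of the decoder's final layer; the clamping bounds are illustrative.

# Sketch of the adaptive adversarial weight; gradients are taken w.r.t. the
# weights of the decoder's last layer, following the VQGAN recipe [1].
import torch

def adaptive_adv_weight(rec_loss, adv_loss, last_layer, delta=1e-4):
    # rec_loss: combined reconstruction + perceptual loss (scalar tensor)
    # adv_loss: adversarial (generator) loss (scalar tensor)
    rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
    adv_grads = torch.autograd.grad(adv_loss, last_layer, retain_graph=True)[0]
    weight = rec_grads.norm() / (adv_grads.norm() + delta)
    # clamp to keep training stable (bounds are illustrative)
    return torch.clamp(weight, 0.0, 1e4).detach()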

Diffusion Model

Having trained the autoencoder, we can train the diffusion model to denoise noisy latent representations from the autoencoder. This involves gradually injecting noise into the outputs of the encoder, and training a diffusion model to learn the reverse process of this noise injection.

The noise is injected into the latent representation in a Markov chain of $T$ time steps. Given the representation at time step $t - 1$, the noisier representation at time step $t$ is obtained as:

$$ q(x_{t} | x_{t - 1}) = \mathcal{N}\big(x_t;\, \sqrt{1 - \beta_t}\, x_{t - 1},\, \beta_t \mathbf{I}\big) $$

It can be shown that we can sample $x_t$ directly from $x_0$ instead of iterating over each intermediate time step:

$$ q(x_{t} | x_0) = \mathcal{N}\big(x_t;\, \sqrt{\bar{\alpha}_t}\, x_0,\, (1 - \bar{\alpha}_t)\mathbf{I}\big), \quad \text{where } \bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s) $$

This is performed in code as:

'''Source: ldm/models/diffusion/ddpm.py'''
def q_sample(self, x_start, t):
    noise = torch.randn_like(x_start)
    return (
        self.sqrt_alphas_cumprod[t] * x_start +
        self.sqrt_one_minus_alphas_cumprod[t] * noise
    )

The forward-process posterior, conditioned on the clean latent $x_0$, which the learned reverse diffusion process is trained to match, can be shown to be:

$$ q(x_{t - 1} | x_t, x_0) = \mathcal{N}\big(x_{t - 1};\, \tilde{\mu}_t(x_t, x_0),\, \tilde{\beta}_t \mathbf{I}\big) $$

This is performed in code as:

'''Source: ldm/models/diffusion/ddpm.py'''
def q_posterior(self, x_start, x_t, t):
    # mean of q(x_{t-1} | x_t, x_0), i.e. mu_tilde(x_t, x_0)
    posterior_mean = (
            self.posterior_mean_coef1[t] * x_start +
            self.posterior_mean_coef2[t] * x_t
    )
    return posterior_mean

The various constants used in the above equations form a part of what’s known as a scheduler, the details of which are presently skipped here. An excellent resource for this is [11].
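
While a full treatment of schedulers is out of scope, here is a minimal sketch of how the constants used above can be derived from a linear variance schedule $\beta_1, \dots, \beta_T$; the schedule endpoints below are illustrative.

# Sketch: derive the diffusion constants from a linear beta schedule.
import torch

T = 1000
betas = torch.linspace(1e-4, 2e-2, T)                  # beta_t
alphas = 1.0 - betas                                   # alpha_t
alphas_cumprod = torch.cumprod(alphas, dim=0)          # alpha_bar_t
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

# coefficients of x_0 and x_t in the forward-process posterior mean mu_tilde
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
# variance beta_tilde_t of the forward-process posterior
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)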

Finally, the diffusion model is trained to minimize a mean squared error in latent space: in the default noise-prediction (“eps”) parameterization, between the injected noise and the predicted noise, or equivalently (in the “x0” parameterization) between the clean latent and the predicted latent. This is the “simple” objective; other variants add further objectives, such as learning the variances of the reverse diffusion process.
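
Putting the pieces together, a single training step for the simple noise-prediction objective can be sketched as follows; model, encoder and cond_model are placeholders for the U-Net, the frozen autoencoder encoder and the conditioning model, and the diffusion constants are those from the scheduler sketch above.

# Sketch of one training step for the "simple" (noise-prediction) objective.
import torch
import torch.nn.functional as F

def training_step(model, encoder, cond_model, image, caption, T=1000):
    with torch.no_grad():
        z0 = encoder(image)                      # clean latent (autoencoder is frozen)
    context = cond_model(caption)                # conditioning embedding
    t = torch.randint(0, T, (z0.shape[0],), device=z0.device)
    noise = torch.randn_like(z0)
    # closed-form forward diffusion: z_t = sqrt(a_bar_t) z_0 + sqrt(1 - a_bar_t) eps
    a = sqrt_alphas_cumprod.to(z0.device)[t].view(-1, 1, 1, 1)
    b = sqrt_one_minus_alphas_cumprod.to(z0.device)[t].view(-1, 1, 1, 1)
    z_t = a * z0 + b * noise
    pred_noise = model(z_t, t, context)          # U-Net predicts the injected noise
    return F.mse_loss(pred_noise, noise)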

Sampling

The reverse denoising described above can be very slow if done iteratively for each time step. So, different methods have been proposed to speed up the reverse diffusion process. Each method builds on its interpretation of the reverse diffusion process and has its pros and cons. Let’s look at some of these in detail:

  • Denoising Diffusion Implicit Models (DDIM) [9]: This method uses a deterministic process (corresponding to zero variance) to model reverse diffusion over strided time steps; in the limit of a large number of steps, this deterministic process can map noisy samples back to the original image. With a particular non-zero choice of variance, the sampler reduces to the original Denoising Diffusion Probabilistic Model (DDPM) [3]. A sketch of a single deterministic DDIM update is given after this list.

  • Pseudo Linear Multi-Step Method (PLMS) [5]: This method views the reverse diffusion process as a differential equation and solves it with pseudo-numerical methods, modifying classical numerical methods by splitting each update into a gradient part and a non-linear transfer part. The authors argue that DDIM is a special case of their pseudo-numerical methods and compare the performance of the two.
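
For illustration, a single deterministic DDIM update (the zero-variance case) can be sketched as follows, reusing the constants defined earlier; fast sampling comes from applying this update over a strided subsequence of time steps.

# Sketch of one deterministic DDIM update step (zero added noise).
import torch

@torch.no_grad()
def ddim_step(model, z_t, t, t_prev, context):
    # alpha_bar at the current and previous (strided) time steps
    a_t = alphas_cumprod[t].item()
    a_prev = alphas_cumprod[t_prev].item() if t_prev >= 0 else 1.0
    eps = model(z_t, torch.tensor([t], device=z_t.device), context)
    # predict the clean latent from the noisy latent and the predicted noise
    pred_z0 = (z_t - (1.0 - a_t) ** 0.5 * eps) / a_t ** 0.5
    # deterministic move to the previous time step (no noise is re-injected)
    return a_prev ** 0.5 * pred_z0 + (1.0 - a_prev) ** 0.5 * eps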

Figure 2. Image generation results measured in terms of FID and time per step. S-PNDM uses gradients from two time steps whereas F-PNDM uses four. Table from [5].

We can see that PLMS has a higher per-step execution time than DDIM. However, the FID from PLMS is much lower than DDIM for the same number of steps.

References

[1] Patrick Esser, Robin Rombach, and Bjorn Ommer. Taming transformers for high-resolution image synthesis, 2020.

[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition, 2015.

[3] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models, 2020.

[4] Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, and Alexei A. Efros. Image-to-image translation with conditional adversarial networks, 2016.

[5] Luping Liu, Yi Ren, Zhijie Lin, and Zhou Zhao. Pseudo numerical methods for diffusion models on manifolds, 2022.

[6] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, and Ilya Sutskever. Learning transferable visual models from natural language supervision, 2021.

[7] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Bjorn Ommer. High-resolution image synthesis with latent diffusion models, 2021.

[8] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation, 2015.

[9] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models, 2020.

[10] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need, 2017.