A look at Diffusion Models

Nathan Bailey
12 min read · Jul 3, 2024


Diffusion Process [3]

Following my previous blog on the PixelCNN architecture, I continued investigating deep learning generative models. One architecture that has always intrigued me is the diffusion model, which forms the basis of generative text-to-image models such as DALL·E.

The full code for this blog can be found on GitHub and was adapted from [1].

Architecture

Diffusion models aim to denoise an image gradually over a series of small steps. The idea is that we can input random noise and repeatedly apply the diffusion model to denoise it, producing a novel image that looks like it came from the training data.

Forward Diffusion

Given a set of training images, we must first add noise to them before they are input into the model; this is called the forward diffusion process. The noising process takes an image and repeatedly applies noise to it over a large number of steps, using the following formula, which adds a small amount of Gaussian noise at each step.

Noised Image Formula [1]: $x_t = \sqrt{1 - \beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon_{t-1}$

Where ε_{t−1} is a standard Gaussian with zero mean and unit variance. When we multiply a standard Gaussian by a constant c, we get a Gaussian with zero mean and variance c². Assuming that our input image has zero mean and unit variance, and given that Var(X + Y) = Var(X) + Var(Y) for independent X and Y, each noised image x_t also has zero mean and unit variance, so after many steps it approximates a standard Gaussian.
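
To make that variance argument concrete, the single-step calculation works out as follows (a short worked derivation, assuming x_{t−1} and ε_{t−1} are independent):

$$\mathrm{Var}(x_t) = (1-\beta_t)\,\mathrm{Var}(x_{t-1}) + \beta_t\,\mathrm{Var}(\epsilon_{t-1}) = (1-\beta_t)\cdot 1 + \beta_t\cdot 1 = 1$$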

We can modify the equation to jump straight from the input image x_0 to the noised image x_t at any given time step t. This is given by the following equation:

Noised Image Formula (reworked) [1]: $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$, where $\bar{\alpha}_t = \prod_{i=1}^{t}(1 - \beta_i)$

Where ε is a single standard Gaussian that combines all of the noise terms added over the previous time steps. The term √(ᾱ_t) is referred to as the signal rate and √(1 − ᾱ_t) as the noise rate in the code below.

The noise parameter β_t does not have to be constant at each timestep; we can choose its value using a diffusion schedule, which describes how β changes with the timestep t.

There are a few options for a diffusion schedule. A simple one is linear, where the value of β increases linearly with t, so that we take smaller noising steps in the early stages of the process.

We can see a linear diffusion schedule implemented in the following code. Diffusion times are given as a floating point tensor in the range 0.0–1.0:

import tensorflow as tf


def linear_diffusion_schedule(
    diffusion_times: tf.Tensor,
) -> tuple[tf.Tensor, tf.Tensor]:
    """Linear Diffusion Schedule Function."""
    min_rate = 0.0001
    max_rate = 0.02
    betas = min_rate + diffusion_times * (max_rate - min_rate)
    alpha = 1 - betas
    alpha_bars = tf.math.cumprod(alpha)
    signal_rates = tf.sqrt(alpha_bars)
    noise_rates = tf.sqrt(1 - alpha_bars)
    return noise_rates, signal_rates
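
As a quick sanity check (illustrative only, not part of the original code), we can evaluate the schedule over a large number of steps and confirm that the noise rate rises towards 1 while the signal rate falls towards 0:

# Evaluate the linear schedule over 1000 steps
T = 1000
diffusion_times = tf.linspace(0.0, 1.0, T)
noise_rates, signal_rates = linear_diffusion_schedule(diffusion_times)
print(noise_rates[0].numpy(), noise_rates[-1].numpy())    # ~0.01 -> ~1.0
print(signal_rates[0].numpy(), signal_rates[-1].numpy())  # ~1.0  -> ~0.007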

A more effective alternative is a cosine schedule, which defines ᾱ_t using the following formula:

Cosine Schedule [1]: $\bar{\alpha}_t = \cos^2\!\left(\frac{t}{T}\cdot\frac{\pi}{2}\right)$

Given the classic trigonometric identity cos² x + sin² x = 1, this gives us the following formula for noising an image:

Noised Image Formula with Cosine Schedule [1]: $x_t = \cos\!\left(\frac{t}{T}\cdot\frac{\pi}{2}\right) x_0 + \sin\!\left(\frac{t}{T}\cdot\frac{\pi}{2}\right) \epsilon$

This is implemented in the following code, where t/T are the diffusion times (in the range 0.0–1.0):

import math


def cosine_diffusion_schedule(
    diffusion_times: tf.Tensor,
) -> tuple[tf.Tensor, tf.Tensor]:
    """Cosine Diffusion Schedule Function."""
    signal_rates = tf.cos(diffusion_times * (math.pi / 2))
    noise_rates = tf.sin(diffusion_times * (math.pi / 2))
    return noise_rates, signal_rates

The paper that introduced the cosine schedule also added an offset term and scaling, shown in the following code block. This prevents the noising steps from being too small at the beginning of the process.

def offset_cosine_diffusion_schedule(
    diffusion_times: tf.Tensor,
) -> tuple[tf.Tensor, tf.Tensor]:
    """Cosine Diffusion Schedule Function with Offset and Scaling."""
    min_signal_rate = 0.02
    max_signal_rate = 0.95
    start_angle = tf.acos(max_signal_rate)
    end_angle = tf.acos(min_signal_rate)
    diffusion_angles = start_angle + diffusion_times * (
        end_angle - start_angle
    )
    signal_rates = tf.cos(diffusion_angles)
    noise_rates = tf.sin(diffusion_angles)
    return noise_rates, signal_rates

A comparison of the schedules can be seen below. The noise level ramps up much more gradually in the cosine schedule, which has been shown to improve training efficiency and the quality of the generated images.

Comparison of Diffusion Schedules [1]
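
A comparison like this can be reproduced with a short plotting snippet along the following lines (illustrative only; matplotlib is not used elsewhere in this post):

import matplotlib.pyplot as plt

# Plot the noise variance of each schedule against the diffusion time t/T
T = 1000
diffusion_times = tf.reshape(tf.linspace(0.0, 1.0, T), (T, 1, 1, 1))
schedules = {
    "linear": linear_diffusion_schedule,
    "cosine": cosine_diffusion_schedule,
    "offset cosine": offset_cosine_diffusion_schedule,
}
for label, schedule in schedules.items():
    noise_rates, _ = schedule(diffusion_times)
    plt.plot(
        tf.linspace(0.0, 1.0, T), tf.reshape(noise_rates**2, (T,)), label=label
    )
plt.xlabel("t/T")
plt.ylabel("noise variance")
plt.legend()
plt.savefig("./output/schedule_comparison.png")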

Training a Network

Now, given a noised image, we need some way to denoise it using a neural network, as this is the aim of our diffusion model. Given a noised image x_t at a random timestep t, the network aims to predict the total noise that was added to that image. Since we know how much noise was actually added, we can use the mean absolute error between the actual and predicted noise as the loss.

We keep two copies of this network: one for training and another, called the exponential moving average (EMA) network, for prediction. The EMA network's weights are an exponential moving average of the training network's weights over previous training steps. This is done because the EMA network was found to be more robust for image generation than the raw trained network.

The network used in the diffusion model is called the U-Net and is shown below.

UNet Architecture [1]

As seen from the diagram, the network is made of 2 halves: a downsampling half and an upsampling half. This is analogous to the encoder/decoder structure of an autoencoder.

The downsampling section compresses the spatial size of the image, and the upsampling section upscales the spatial dimension of the image, back to the original size.

The downsampling half consists of DownBlocks, which reduce the spatial dimensions of the input whilst increasing the number of channels, whereas the upsampling half consists of UpBlocks, which increase the spatial dimensions whilst reducing the number of channels.

The U-Net architecture also introduces skip connections between the Up and Down blocks, which enables information to shortcut parts of the network and flow through the later layers.

The final output of the model is the predicted noise, which is a tensor the same size as the input images.

Sinusoidal Embedding

As seen in the diagram, the inputs to the network are the noise variance (1 − ᾱ_t) and the noisy images. The noise variance is passed through a sinusoidal embedding block and then upsampled so that it can be concatenated with the output of a convolutional block. The sinusoidal embedding block converts a scalar value into a higher-dimensional vector, providing a richer representation.

Given a scalar value, the sinusoidal embedding produces the following vector:

Sinusoidal Embedding Formula [1]: $\gamma(x) = \left(\sin(2\pi f_1 x), \ldots, \sin(2\pi f_L x), \cos(2\pi f_1 x), \ldots, \cos(2\pi f_L x)\right)$

L is a hyperparameter (half the embedding size) and the frequencies f_k are spaced evenly on a logarithmic scale between 1 and 1000, i.e. $f_k = \exp\!\left(\frac{(k-1)\ln 1000}{L-1}\right)$ for $k = 1, \ldots, L$.

We can implement the sinusoidal embedding in a small Python function, shown below. We first generate a tensor of 16 frequencies (half the noise embedding size of 32), whose logarithms are equally spaced between ln(1.0) and ln(1000).

These frequencies are then multiplied by 2π, and the sines and cosines of the resulting angular speeds (multiplied by the input value) are concatenated to give us our embedding.

def sinusoidal_embedding(
    x: tf.Tensor, noise_embedding_size: int
) -> tf.Tensor:
    """Sinusoidal Embedding Function."""
    frequencies = tf.exp(
        tf.linspace(
            tf.math.log(1.0),
            tf.math.log(1000.0),
            noise_embedding_size // 2,
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    embeddings = tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3
    )
    return embeddings
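
A quick shape check (illustrative, not in the original code) shows how a single scalar noise variance is expanded into a 32-dimensional embedding:

# Embed one scalar value (shape (1, 1, 1, 1)) into a vector of size 32
noise_variance = tf.fill((1, 1, 1, 1), 0.5)
embedding = sinusoidal_embedding(noise_variance, noise_embedding_size=32)
print(embedding.shape)  # (1, 1, 1, 32)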

Up/Down Blocks

The structure of Up and Down blocks are shown in the following diagram:

Up and Down Blocks [1]

As seen, the main components of these blocks are residual blocks, which are shown below.

Residual Block [2]

These blocks pass the input tensor through a few convolutional layers and add the input (element-wise) to the output of the final layer. They allow us to build deeper networks without suffering from the degradation problem, where deeper networks produce higher training and test errors than shallower ones. They have also been shown to increase the accuracy of networks. [2]

We can implement a ResidualBlock in Keras using the following code:

from typing import Any

from tensorflow import keras


class ResidualBlock(keras.layers.Layer):
    """Residual Block Layer."""

    def __init__(self, width: int, **kwargs: Any) -> None:
        """Init variables and layers."""
        super().__init__()
        self.width = width
        self.downsample_layer = keras.layers.Conv2D(width, kernel_size=1)
        self.batch_norm_layer = keras.layers.BatchNormalization(
            center=False, scale=False
        )
        self.conv_layer_1 = keras.layers.Conv2D(
            width, **kwargs, activation=keras.activations.swish
        )
        self.conv_layer_2 = keras.layers.Conv2D(width, **kwargs)

    def call(self, input: tf.Tensor) -> tf.Tensor:
        """Forward pass."""
        if self.width == input.shape[3]:
            residual = input
        else:
            residual = self.downsample_layer(input)
        x = self.batch_norm_layer(input)
        x = self.conv_layer_1(x)
        x = self.conv_layer_2(x)
        return residual + x

In a simple block, we pass the input through a batch normalization layer and two convolutional layers. If the depth of the input does not match the desired output depth, we pass the input through a 1×1 convolutional layer so the two can be added together. We use "same" padding and a stride of 1, so the spatial dimensions of the input always match those of the output and need no adjustment.
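
A quick shape check (illustrative only) confirms this behaviour: the spatial dimensions are preserved, while the channel dimension is projected to the requested width:

# Project a 3-channel input to 32 channels while keeping the 64x64 spatial size
block = ResidualBlock(width=32, kernel_size=3, padding="same")
x = tf.random.normal((1, 64, 64, 3))
print(block(x).shape)  # (1, 64, 64, 32)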

Down Blocks are implemented in the following code.

class DownBlock(keras.layers.Layer):
    """Down Block Layer."""

    def __init__(self, block_depth: int, width: int, **kwargs: Any) -> None:
        """Init Layers."""
        super().__init__()
        self.residual_blocks = [
            ResidualBlock(width, **kwargs) for _ in range(block_depth)
        ]
        self.average_pool = keras.layers.AveragePooling2D(pool_size=2)

    def call(self, input: tf.Tensor) -> tuple[tf.Tensor, list[tf.Tensor]]:
        """Forward pass."""
        x = input
        skips = []
        for residual in self.residual_blocks:
            x = residual(x)
            skips.append(tf.identity(x))
        x = self.average_pool(x)
        return x, skips

In this block, we pass the input through the residual blocks, making sure to capture the output of each block into a list to use as skip connections later. Finally, we downsample the spatial dimensions of the output feature map using an average pooling layer.

Similarly, Up Blocks are implemented in the following code.

class UpBlock(keras.layers.Layer):
    """Up Block Layer."""

    def __init__(self, block_depth: int, width: int, **kwargs: Any) -> None:
        """Init Layers."""
        super().__init__()
        self.residual_blocks = [
            ResidualBlock(width, **kwargs) for _ in range(block_depth)
        ]
        self.up_sampling = keras.layers.UpSampling2D(
            size=2, interpolation="bilinear"
        )
        self.concat_layer = keras.layers.Concatenate()

    def call(
        self, input_list: list[tf.Tensor | list[tf.Tensor]]
    ) -> tf.Tensor:
        """Forward pass."""
        x, skips = input_list
        x = self.up_sampling(x)
        for residual in self.residual_blocks:
            x = self.concat_layer([x, skips.pop()])
            x = residual(x)
        return x

This follows a very similar pattern to the Down Block: the input is first upsampled, and each incoming skip tensor is concatenated with the feature map before it is passed to the corresponding residual block.
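
A small round trip (illustrative only) shows how the two blocks fit together: the DownBlock halves the spatial dimensions and returns its skip tensors, and the UpBlock restores the spatial dimensions while consuming those skips:

down_block = DownBlock(block_depth=2, width=32, kernel_size=3, padding="same")
up_block = UpBlock(block_depth=2, width=32, kernel_size=3, padding="same")

x = tf.random.normal((1, 64, 64, 3))
x_down, skips = down_block(x)
print(x_down.shape)                     # (1, 32, 32, 32)
print(up_block([x_down, skips]).shape)  # (1, 64, 64, 32)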

Finally, we can implement a UNet Model in the following function. This simply combines the blocks as shown in the diagram above.

def create_unet_model(
    image_size: int,
    block_depth: int,
    filter_list: list[int],
    noise_embedding_size: int,
) -> keras.models.Model:
    """Create a UNET Model from Building Blocks."""
    assert len(filter_list) == 4

    noisy_images = keras.layers.Input(shape=(image_size, image_size, 3))
    x = keras.layers.Conv2D(32, kernel_size=1)(noisy_images)

    noise_variances = keras.layers.Input(shape=(1, 1, 1))
    noise_embedding = keras.layers.Lambda(
        sinusoidal_embedding,
        arguments={"noise_embedding_size": noise_embedding_size},
    )(noise_variances)
    noise_embedding = keras.layers.UpSampling2D(
        size=image_size, interpolation="nearest"
    )(noise_embedding)

    x = keras.layers.Concatenate()([x, noise_embedding])

    skips_total = []
    for filter_width in filter_list[:-1]:
        x, skips = DownBlock(
            block_depth=block_depth,
            width=filter_width,
            kernel_size=3,
            padding="same",
        )(x)
        skips_total += skips

    for _ in range(2):
        x = ResidualBlock(
            width=filter_list[-1], kernel_size=3, padding="same"
        )(x)

    for filter_width in filter_list[:-1][::-1]:
        x = UpBlock(
            block_depth=block_depth,
            width=filter_width,
            kernel_size=3,
            padding="same",
        )([x, [skips_total.pop(), skips_total.pop()]])

    x = keras.layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)

    return keras.models.Model([noisy_images, noise_variances], x)
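
As a sanity check (illustrative; the 64×64 image size and embedding size of 32 are assumptions here, while the filter widths match those used in the DiffusionModel class below), the model maps a noisy image and a noise variance to a predicted noise tensor of the same size as the image:

unet = create_unet_model(
    image_size=64,
    block_depth=2,
    filter_list=[32, 64, 96, 128],
    noise_embedding_size=32,
)
noisy_images = tf.random.normal((1, 64, 64, 3))
noise_variances = tf.random.uniform((1, 1, 1, 1))
print(unet([noisy_images, noise_variances]).shape)  # (1, 64, 64, 3)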

Diffusion Model

We wrap the UNet model in a larger DiffusionModel class. This allows us to implement a custom train step, in which we noise an image and use the UNet model to predict the noise that was added to it.

Our init function sets up the model, creating the two networks, the diffusion schedule and the loss tracker.

class DiffusionModel(keras.models.Model):
    """Diffusion Model Class."""

    def __init__(
        self,
        image_size: int,
        batch_size: int,
        ema_value: float,
        noise_embedding_size: int,
    ) -> None:
        """Init variables and model."""
        super().__init__()
        self.image_size = image_size
        self.batch_size = batch_size
        self.ema_value = ema_value
        self.normalizer = keras.layers.Normalization()
        self.network = create_unet_model(
            image_size, 2, [32, 64, 96, 128], noise_embedding_size
        )
        self.ema_network = keras.models.clone_model(self.network)
        self.diffusion_schedule = offset_cosine_diffusion_schedule
        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")

    @property
    def metrics(self) -> list[keras.metrics.Metric]:
        """Register Metrics."""
        return [self.noise_loss_tracker]

Next, we set up the denoise function, which takes a noisy image, predicts the noise added to it, and then denoises the image using this predicted noise.

def denoise(
    self,
    noisy_images: tf.Tensor,
    noise_rates: tf.Tensor,
    signal_rates: tf.Tensor,
    training: bool = True,
) -> tuple[tf.Tensor, tf.Tensor]:
    """Predict noise and denoise images with the predicted noise."""
    if training:
        network = self.network
    else:
        network = self.ema_network

    pred_noises = network(
        [noisy_images, noise_rates**2], training=training
    )
    pred_images = (
        noisy_images - noise_rates * pred_noises
    ) / signal_rates
    return pred_noises, pred_images

Our custom train step is below. Here, we generate a noised image from an input image at a random timestep t, pass it to our denoise function to predict the noise added to it, and then compute the loss and update the weights of the network.

def train_step(self, images: tf.Tensor) -> dict[str, tf.Tensor]:
    """Training step."""
    images = self.normalizer(images, training=True)
    noises = tf.random.normal(
        shape=(self.batch_size, self.image_size, self.image_size, 3)
    )

    # Generate an image x(t) at a random timestep t
    diffusion_times = tf.random.uniform(
        shape=(self.batch_size, 1, 1, 1), minval=0.0, maxval=1.0
    )
    noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

    noisy_images = signal_rates * images + noise_rates * noises

    with tf.GradientTape() as tape:
        pred_noises, _ = self.denoise(
            noisy_images, noise_rates, signal_rates
        )
        noise_loss = self.loss(noises, pred_noises)

    gradients = tape.gradient(noise_loss, self.network.trainable_weights)
    self.optimizer.apply_gradients(
        zip(gradients, self.network.trainable_weights)
    )
    self.noise_loss_tracker.update_state(noise_loss)

    for weight, ema_weight in zip(
        self.network.weights, self.ema_network.weights
    ):
        ema_weight.assign(
            self.ema_value * ema_weight + (1 - self.ema_value) * weight
        )
    return {m.name: m.result() for m in self.metrics}

Sampling New Images from the Network

Once we have trained a network to predict the noise added to an image, we need a process to generate a new image starting from random noise.

As seen in the training process, the model is trained to predict the total amount of noise added to an image. We could jump straight from the noisy image to the original image at timestep 0 by undoing all of the noise at once; however, in practice this leads to poor results. Instead, we undo the noise gradually over a number of steps, moving from x_t to x_(t−1).

Given a noised image and the noise predicted by the network, we first estimate x_0 and then re-apply the predicted noise in the direction pointing towards x_t to estimate x_(t−1). This process is repeated, starting from x_(t−1), until we reach x_0. At each iteration, we predict the noise using the network and the current image.

Generated Image Formula [1]: $\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\hat{\epsilon}_t}{\sqrt{\bar{\alpha}_t}}, \qquad x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\,\hat{\epsilon}_t$

We implement this in our Diffusion Model class as shown below.

Given a number of iterations (steps), this function implements the above formula, outputting the final predicted images at the end of the last iteration.

We first work out the noise and signal rates at the current timestep and use these to estimate the x_0 images from the model's predicted noise. We then re-apply the predicted noise, giving us the images at the next timestep t − 1.

def reverse_diffusion(
    self, initial_noise: tf.Tensor, diffusion_steps: int
) -> tf.Tensor:
    """
    Reverse diffusion process.
    Take a noisy image and denoise it over timesteps
    to produce a clean, generated image.
    """
    num_images = initial_noise.shape[0]
    step_size = 1.0 / diffusion_steps
    current_images = initial_noise
    # Timesteps are float values between 0 and 1
    # Start from the final step t (1) and work backwards
    for step in range(diffusion_steps):
        diffusion_times = (
            tf.ones((num_images, 1, 1, 1)) - step * step_size
        )
        noise_rates, signal_rates = self.diffusion_schedule(
            diffusion_times
        )
        pred_noises, pred_images = self.denoise(
            current_images, noise_rates, signal_rates, training=False
        )
        next_diffusion_times = diffusion_times - step_size
        next_noise_rates, next_signal_rates = self.diffusion_schedule(
            next_diffusion_times
        )
        current_images = (
            next_signal_rates * pred_images
            + next_noise_rates * pred_noises
        )
    return pred_images

We can wrap this into a generate function which can then be called to generate images on the fly. This creates random noise, generates images and then denormalizes them, ready to be saved.

def generate(
    self,
    num_images: int,
    diffusion_steps: int,
    initial_noise: tf.Tensor | None = None,
) -> tf.Tensor:
    """Generate image from the model."""
    if initial_noise is None:
        initial_noise = tf.random.normal(
            shape=(num_images, self.image_size, self.image_size, 3)
        )
    generated_images = self.reverse_diffusion(
        initial_noise, diffusion_steps
    )
    generated_images = self.denormalize(generated_images)
    return generated_images
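
The denormalize helper called here is not shown above. A minimal sketch, assuming it simply inverts the Normalization layer set up in __init__ and clips back to the valid pixel range (the exact implementation in [1] may differ slightly):

def denormalize(self, images: tf.Tensor) -> tf.Tensor:
    """Invert the normalizer and clip images back to the range [0, 1]."""
    images = self.normalizer.mean + images * self.normalizer.variance**0.5
    return tf.clip_by_value(images, 0.0, 1.0)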

We can incorporate this generate function into a Keras callback, which we use to generate images at the end of each epoch to monitor the training process.

class ImageGenerator(keras.callbacks.Callback):
    """Image Generator Callback."""

    def __init__(self, num_img: int, num_diffusion_steps: int) -> None:
        """Init variables."""
        super().__init__()
        self.num_img = num_img
        self.num_diffusion_steps = num_diffusion_steps

    def on_epoch_end(self, epoch: int, logs: None = None) -> None:
        """Generate and save images on epoch end."""
        generated_images = self.model.generate(
            num_images=self.num_img,
            diffusion_steps=self.num_diffusion_steps,
        ).numpy()
        display(
            generated_images,
            save_to=f"./output/generated_image_epoch_{epoch}.png",
        )

model = DiffusionModel(
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    ema_value=EMA,
    noise_embedding_size=NOISE_EMBEDDING_SIZE,
)
model.normalizer.adapt(train_data)
model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    ),
    loss=keras.losses.mean_absolute_error,
)
model.network.summary(expand_nested=True)

image_generator = ImageGenerator(
    num_img=100,
    num_diffusion_steps=NUM_DIFFUSION_STEPS,
)
model.fit(train_data, epochs=EPOCHS, callbacks=[image_generator], verbose=2)

After training for 100 epochs on the Oxford-102 Flower Dataset, we can see from the results below that our diffusion model does an excellent job of generating images from random noise!

Training Images
Generated Images

Conclusions

This blog outlined the architecture and showed the implementation of a diffusion model. We saw great success after 100 epochs, with the diffusion model able to generate novel images resembling the training data just by inputting random noise!

The full code for this blog can be found on GitHub.

  1. Generative Deep Learning, David Foster
  2. Deep Residual Learning for Image Recognition, Kaiming He et al., https://arxiv.org/abs/1512.03385
