A Look at the Pixel CNN Architecture
In an attempt to become better versed in the foundations of generative AI, I recently embarked on a learning spree, exposing myself to several fundamental generative neural networks. I have been following the excellent book Generative Deep Learning by David Foster [1], which outlines, explains and implements many key generative neural networks in Keras. One method that particularly caught my eye was the pixel convolutional neural network (PixelCNN).
The full code for this blog can be found on GitHub.
Architecture
This generative neural network creates new images pixel by pixel, predicting each pixel from the pixels that precede it. Given the pixels generated so far, it outputs a probability distribution over the value of the next pixel, i.e. one probability for each possible pixel value. We can then sample a pixel value from this distribution.
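Written as a formula, this is the standard autoregressive factorisation: for an image whose n pixels x_1 to x_n are taken in a fixed order, the joint probability is the product of the per-pixel conditional distributions that the network predicts. Sampling therefore draws x_1 first, then x_2 given x_1, and so on.

```latex
p(\mathbf{x}) = \prod_{i=1}^{n} p\left(x_i \mid x_1, \ldots, x_{i-1}\right)
```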
To apply an autoregressive model like this to images, we must first place an order on the pixels and ensure that the filters only see the pixels that come before the one we are currently predicting. Assuming a single-channel (grayscale) image, we order the pixels row by row from the top left to the bottom right, so the top-left pixel has index 0, the pixel to its right has index 1, and so on.
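This raster ordering boils down to index = row × width + col. The small plain-Python sketch below computes it and lists the pixels a given pixel may be conditioned on; the helper names raster_index and preceding_pixels are hypothetical and only serve the illustration.

```python
def raster_index(row: int, col: int, width: int) -> int:
    """Position of pixel (row, col) in a left-to-right, top-to-bottom ordering."""
    return row * width + col


def preceding_pixels(row: int, col: int, width: int) -> list[tuple[int, int]]:
    """All pixels that come before (row, col) in raster order, i.e. the
    pixels the model is allowed to condition on when predicting it."""
    idx = raster_index(row, col, width)
    # divmod turns a flat index back into (row, col)
    return [divmod(i, width) for i in range(idx)]


if __name__ == "__main__":
    width = 4
    print(raster_index(0, 0, width))      # 0: the top-left pixel
    print(raster_index(0, 1, width))      # 1: its right-hand neighbour
    print(raster_index(1, 0, width))      # 4: first pixel of the second row
    print(preceding_pixels(1, 1, width))  # the whole first row, then (1, 0)
```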
The PixelCNN then uses masked convolutional layers, in which some filter weights are set to 0 by a filter mask. The figure below shows this mask for a single-channel image. The mask is shaped so that, when the filter computes the output value at a given index, it does not use the input values at any later pixel (and, for the type A mask described next, not at the current pixel either). For a single-channel image input, the same pattern of ones and zeros is repeated across all channel dimensions of the filter.
Two types of mask are used: type A and type B. A type A mask sets the value at the central position to 0, whereas a type B mask sets it to 1.
A type A mask is always used in the first layer of the network. This ensures that the layer cannot use the current pixel of the image, which is the pixel we want to predict. The subsequent layers can use type B masks, which let them use the value at the current position in the output of the previous layer. This is safe because the first layer never saw the current image pixel, so the feature at the current position depends only on earlier pixels; letting later layers use it gives them more information for feature extraction without leaking the answer.
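As a quick illustration of the two mask types, the NumPy sketch below builds the single-channel mask for a square kernel. The function name make_mask is hypothetical; it follows the same row and column rule as the Keras layer that follows, but returns a plain 2D array so the patterns are easy to print.

```python
import numpy as np


def make_mask(kernel_size: int, mask_type: str) -> np.ndarray:
    """Single-channel PixelCNN mask for a kernel_size x kernel_size filter.

    1 = weight may be used, 0 = weight is forced to zero.
    """
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    centre = kernel_size // 2
    mask = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    mask[:centre, :] = 1.0          # every row above the centre row
    mask[centre, :centre] = 1.0     # pixels left of the centre in the centre row
    if mask_type == "B":
        mask[centre, centre] = 1.0  # type B may also use the centre position
    return mask


print(make_mask(3, "A"))
print(make_mask(5, "B"))
```

With a 3x3 kernel, type A keeps the whole top row plus the left neighbour of the centre; type B additionally keeps the centre itself.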
We can implement a masked convolutional layer using the following code:
```python
from typing import Any, Literal

import numpy as np
import tensorflow as tf
from tensorflow import keras


class MaskedConv2D(keras.layers.Layer):
    def __init__(self, mask_type: Literal["A", "B"], **kwargs: Any) -> None:
        super().__init__()
        self.mask_type = mask_type
        # All remaining keyword arguments (filters, kernel_size, ...) go to Conv2D
        self.conv_layer = keras.layers.Conv2D(**kwargs)

    def build(self, input_shape: tuple[int, ...]) -> None:
        self.conv_layer.build(input_shape)
        # Kernel shape is (KH, KW, depth, num_filters)
        kernel_shape = tuple(self.conv_layer.kernel.shape)
        self.mask = np.zeros(shape=kernel_shape, dtype="float32")
        # Unmask every row above the centre row
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        # In the centre row, unmask the columns left of the centre
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        # If the mask type is B, also unmask the centre value
        if self.mask_type == "B":
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        # Zero the masked weights before every forward pass
        self.conv_layer.kernel.assign(self.conv_layer.kernel * self.mask)
        return self.conv_layer(inputs)
```
As shown, for a kernel of size NxN, we set every row above the centre row (the first N // 2 rows) to 1. In the centre row, we set all values to the left of the central value to 1. For a type A mask the central value stays 0; for a type B mask it is set to 1.
A PixelCNN is built from residual blocks with the structure shown below. A 1x1 convolution first halves the number of filters, a type B masked 3x3 convolution is applied, and a final 1x1 convolution scales back up to the original number of filters so that the block's output can be added to its input.
This is shown in the code below:
```python
class ResidualBlock(keras.layers.Layer):
    def __init__(self, filters: int) -> None:
        super().__init__()
        # 1x1 convolution halves the number of filters
        self.conv_layer_1 = keras.layers.Conv2D(filters=filters // 2, kernel_size=1, activation="relu")
        # Type B masked 3x3 convolution on the reduced feature maps
        self.pixel_conv = MaskedConv2D(mask_type="B", filters=filters // 2, kernel_size=3, activation="relu", padding="same")
        # 1x1 convolution restores the original number of filters
        self.conv_layer_2 = keras.layers.Conv2D(filters=filters, kernel_size=1, activation="relu")

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Spatial and depth dimensions do not change between input and output,
        so nothing needs to be downsampled: we just add input and output.
        """
        x = self.conv_layer_1(inputs)
        x = self.pixel_conv(x)
        x = self.conv_layer_2(x)
        return inputs + x
```
We build the PixelCNN from masked convolutional layers and residual blocks; the following diagram shows the architecture from the original paper. The final layer is a convolutional layer with a 1x1 filter and a softmax (or sigmoid) activation. Its output has the same spatial dimensions as the input image, with one channel for each value a pixel can take.
This gives a probability distribution over the possible values of every pixel. We use the sparse categorical cross-entropy loss, so the labels are the actual (integer) pixel values, and minimising the loss pushes the predicted probability of the true pixel value as high as possible.
An example PixelCNN is built by the following function.
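To make the loss concrete, here is a plain NumPy version of what sparse categorical cross-entropy computes on a per-pixel softmax output. It is a sketch for illustration: the function name is hypothetical, and Keras additionally guards against log(0), which this version omits.

```python
import numpy as np


def sparse_categorical_crossentropy(labels: np.ndarray, probs: np.ndarray) -> float:
    """Mean negative log-probability of the true pixel value.

    labels: integer pixel values, shape (batch, height, width)
    probs:  softmax output, shape (batch, height, width, pixel_levels)
    """
    # Pick out the predicted probability of the true value at every pixel
    true_probs = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    return float(-np.mean(np.log(true_probs)))


probs = np.array([[[[0.9, 0.1], [0.2, 0.8]]]])  # (batch=1, height=1, width=2, levels=2)
labels = np.array([[[0, 1]]])                   # true level of each pixel
print(sparse_categorical_crossentropy(labels, probs))  # -(ln 0.9 + ln 0.8) / 2, about 0.164
```

A perfect prediction (probability 1 on every true value) gives a loss of 0, so training drives the true-value probabilities towards 1.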
```python
def build_model(input_shape: tuple[int, ...], num_residual_blocks: int, num_filters: int, pixel_levels: int) -> keras.models.Model:
    inputs = keras.layers.Input(shape=input_shape)
    # First layer uses a type A mask so no output can see its own input pixel
    x = MaskedConv2D(mask_type="A", filters=num_filters, kernel_size=7, activation="relu", padding="same")(inputs)
    for _ in range(num_residual_blocks):
        x = ResidualBlock(filters=num_filters)(x)
    # Two 1x1 type B layers before the output layer
    for _ in range(2):
        x = MaskedConv2D(mask_type="B", filters=num_filters, kernel_size=1, padding="valid", activation="relu")(x)
    # One softmax distribution over pixel_levels values for every pixel
    output = keras.layers.Conv2D(filters=pixel_levels, kernel_size=1, activation="softmax", padding="valid")(x)
    return keras.models.Model(inputs, output)
```
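One way to check that stacking these layers never lets a prediction peek at its own pixel is to track, for each output position, which input pixels can reach it. The pure-Python sketch below does this for the layer stack of build_model with 5 residual blocks (the setting used in the experiments below). It is a dependency check written for illustration, not a Keras model, and the helper names are hypothetical.

```python
def allowed_offsets(kernel_size: int, mask_type: str) -> list[tuple[int, int]]:
    """(row, col) offsets that a masked, 'same'-padded convolution may read."""
    centre = kernel_size // 2
    offsets = []
    for i in range(kernel_size):
        for j in range(kernel_size):
            before_centre = i < centre or (i == centre and j < centre)
            at_centre = i == centre and j == centre
            if before_centre or (at_centre and mask_type == "B"):
                offsets.append((i - centre, j - centre))
    return offsets


def dependencies(height: int, width: int, stack: list[tuple[int, str]]) -> dict:
    """Map every output position to the set of input pixels it can depend on
    after applying the (kernel_size, mask_type) layers in `stack`."""
    # Before any layer, each position depends only on itself
    deps = {(r, c): {(r, c)} for r in range(height) for c in range(width)}
    for kernel_size, mask_type in stack:
        offsets = allowed_offsets(kernel_size, mask_type)
        new_deps = {}
        for r, c in deps:
            reachable = set()
            for dr, dc in offsets:
                src = (r + dr, c + dc)
                if src in deps:  # positions outside the image are zero padding
                    reachable |= deps[src]
            new_deps[(r, c)] = reachable
        deps = new_deps
    return deps


# A 7x7 type A layer, one 3x3 type B layer per residual block, then two 1x1
# type B layers. The unmasked 1x1 Conv2D layers and the residual additions do
# not widen the dependency sets, because a type B layer already includes the
# centre position.
stack = [(7, "A")] + [(3, "B")] * 5 + [(1, "B")] * 2
height = width = 8
deps = dependencies(height, width, stack)

# Every input pixel an output depends on must come strictly earlier in raster order
causal = all(
    r_in * width + c_in < r * width + c
    for (r, c), sources in deps.items()
    for (r_in, c_in) in sources
)
print(causal)        # True
print(deps[(0, 0)])  # set(): the first pixel is predicted from nothing
```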
Training and Generating
To train the network, we simply input our images, typically scaled so that their values lie in the range 0–1. The labels are the same images with integer values, e.g. 0–255. Because the masks stop each output from seeing its own input pixel, the image itself can serve as the target.
To generate a new image, we have to generate the pixels one by one. We feed an array of zeros with the same shape as the image into the network. The output at the first position gives the probability distribution for the value of the first pixel; we sample from it and write the sampled value into the array. We then repeat the process for the next pixel, and so on, until every pixel has been sampled.
This is best shown in the code below.
```python
# Methods of the ImageGenerator callback shown later in this post
def sample_from(self, probs: np.ndarray, temperature: float) -> int:
    # Sharpen (temperature < 1) or flatten (temperature > 1) the distribution
    probs = np.asarray(probs, dtype=np.float64) ** (1 / temperature)
    probs = probs / np.sum(probs)
    # Choose a pixel value at random using the distribution
    return np.random.choice(len(probs), p=probs)

def generate(self, temperature: float) -> np.ndarray:
    # Start from an all-zero batch of images
    generated_images = np.zeros(shape=(self.num_img,) + self.model.input_shape[1:])
    _, rows, cols, channels = generated_images.shape
    for row in range(rows):
        for col in range(cols):
            for channel in range(channels):
                # Predict pixels one by one using the previously predicted pixels
                probs = self.model.predict(generated_images, verbose=0)[:, row, col, :]
                generated_images[:, row, col, channel] = [self.sample_from(x, temperature) for x in probs]
                # Rescale the sampled level to the same range as the training inputs
                generated_images[:, row, col, channel] /= self.pixel_levels
    return generated_images
```
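The temperature in sample_from reshapes the predicted distribution before sampling: values below 1 sharpen it towards the most likely pixel value, and values above 1 flatten it towards uniform. The hypothetical helper apply_temperature below isolates that step so its effect can be checked on a toy distribution.

```python
import numpy as np


def apply_temperature(probs: np.ndarray, temperature: float) -> np.ndarray:
    """Reweight a categorical distribution the same way sample_from does."""
    scaled = np.asarray(probs, dtype=np.float64) ** (1 / temperature)
    return scaled / scaled.sum()


probs = np.array([0.5, 0.3, 0.2])
print(apply_temperature(probs, 1.0))  # unchanged
print(apply_temperature(probs, 0.5))  # sharper: squares, then renormalises
print(apply_temperature(probs, 2.0))  # flatter: closer to uniform
```

With a temperature of 0.5 the probabilities are squared and renormalised (0.25, 0.09 and 0.04 divided by 0.38), so the most likely value rises from 0.5 to about 0.66.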
Experiments
For our first experiment, we train a PixelCNN on the MNIST dataset. We first preprocess the dataset to reduce the number of values a pixel can take. The original MNIST pixel values lie in the range 0–255, which would require a 256-channel final layer and make the PixelCNN expensive to train. As most pixels are black, we restrict each pixel to 2 possible values.
The input data are floating-point values within the range 0–1, and the output data (labels) take the values 0 or 1. This way we input image data as usual (floating-point values within 0–1) while providing the integer labels that the sparse categorical cross-entropy loss expects.
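```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

IMAGE_SIZE = 28
PIXEL_LEVELS = 2

(x_train, _), (_, _) = keras.datasets.mnist.load_data()


def preprocess_images(images: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
    imgs_int = np.expand_dims(images, -1)
    imgs_int = tf.image.resize(imgs_int, (IMAGE_SIZE, IMAGE_SIZE)).numpy()
    # Labels: integer levels 0 .. PIXEL_LEVELS - 1
    imgs_int = (imgs_int / (256 / PIXEL_LEVELS)).astype(np.uint8)
    # Inputs: level k becomes the float k / PIXEL_LEVELS (same scaling as generate)
    imgs = imgs_int.astype("float32")
    imgs = imgs / PIXEL_LEVELS
    return imgs, imgs_int


input_data, output_data = preprocess_images(x_train)
```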
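To see what this quantisation does to individual pixel values, the sketch below repeats the arithmetic of preprocess_images on a few raw values, leaving out the resize step. The name quantise is hypothetical. Note that level k becomes the input k / PIXEL_LEVELS, so with 2 levels the inputs are 0 and 0.5; this matches the division by pixel_levels in generate.

```python
import numpy as np


def quantise(images: np.ndarray, pixel_levels: int) -> tuple[np.ndarray, np.ndarray]:
    """Same arithmetic as preprocess_images, without the resize step."""
    labels = (images / (256 / pixel_levels)).astype(np.uint8)  # integers 0..pixel_levels-1
    inputs = labels.astype("float32") / pixel_levels           # floats k / pixel_levels
    return inputs, labels


raw = np.array([0, 100, 127, 128, 200, 255], dtype=np.uint8)
inputs, labels = quantise(raw, pixel_levels=2)
print(labels)  # [0 0 0 1 1 1]
print(inputs)  # 0.0 for the dark values, 0.5 for the bright ones

# The same function with 16 levels, as used later for Fashion-MNIST
_, labels16 = quantise(raw, pixel_levels=16)
print(labels16)  # levels between 0 and 15
```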
We can then create our model using the function above; we used 5 residual blocks and 128 filters. I also found that better results were achieved by keeping the number of filters in the first convolutional layer of the residual block equal to the number of input filters, instead of halving it.
We also create a callback so that we can generate images at the end of every epoch and monitor how well the network is training. Its generate method, shown below, generates a set number of images, which are then saved to disk.
```python
class ImageGenerator(keras.callbacks.Callback):
    def __init__(self, num_img: int, pixel_levels: int, save_dir: str = "output") -> None:
        super().__init__()
        self.num_img = num_img
        self.pixel_levels = pixel_levels
        self.save_dir = save_dir

    def sample_from(self, probs: np.ndarray, temperature: float) -> int:
        probs = np.asarray(probs, dtype=np.float64) ** (1 / temperature)
        probs = probs / np.sum(probs)
        # Choose a pixel value at random using the distribution
        return np.random.choice(len(probs), p=probs)

    def generate(self, temperature: float) -> np.ndarray:
        generated_images = np.zeros(shape=(self.num_img,) + self.model.input_shape[1:])
        _, rows, cols, channels = generated_images.shape
        for row in range(rows):
            for col in range(cols):
                for channel in range(channels):
                    # Predict pixels one by one using the previously predicted pixels
                    probs = self.model.predict(generated_images, verbose=0)[:, row, col, :]
                    generated_images[:, row, col, channel] = [self.sample_from(x, temperature) for x in probs]
                    generated_images[:, row, col, channel] /= self.pixel_levels
        return generated_images

    def on_epoch_end(self, epoch, logs=None) -> None:
        generated_images = self.generate(temperature=1.0)
        # display() is a plotting/saving helper that is not shown in this post
        display(generated_images, save_to=f"{self.save_dir}/generated_img_{epoch}.png")
```
We train the network for up to 100 epochs with a batch size of 128. The initial learning rate is 0.01 and is reduced by a factor of 10 each time the training loss plateaus (no improvement for 3 epochs); training stops early if the loss does not improve for 10 epochs.
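```python
BATCH_SIZE = 128
EPOCHS = 100
SAVE_DIR = "output"  # any writable directory; "output" is the callback's default

model = build_model(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1), num_residual_blocks=5, num_filters=128, pixel_levels=PIXEL_LEVELS)

opt = keras.optimizers.Adam(learning_rate=0.01)
# Divide the learning rate by 10 when the training loss stops improving for 3 epochs
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.1, patience=3, verbose=1, min_lr=1e-7, min_delta=1e-4)
# Stop training if the loss has not improved for 10 epochs
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="loss", verbose=1, patience=10)
model.compile(optimizer=opt, loss="sparse_categorical_crossentropy")
img_generator_callback = ImageGenerator(num_img=100, pixel_levels=PIXEL_LEVELS, save_dir=SAVE_DIR)
model.fit(input_data, output_data, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[lr_scheduler, early_stopping, img_generator_callback], verbose=2)
```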
We can inspect the images generated after the 100th epoch: whilst they are not perfect, some of the digits are clearly generated correctly by the network.
After tackling MNIST, we turn our attention to a more complex dataset, Fashion-MNIST. As with MNIST, its pixel values lie in the range 0–255.
For MNIST we used 2 pixel levels. For Fashion-MNIST, whose pixels take many intermediate grey values instead of being mostly black, we use 16 pixel levels. The figure below compares the original images with the images reduced to 16 levels; there is little visible difference between them. We also reduce the spatial dimensions of the images from 28x28 to 16x16 to speed up training.
As with MNIST, the images generated after 100 epochs are not perfect, but the network generates recognisable classes of clothing.
Colour Images
As stated earlier, for a single-channel image we simply repeat the ones and zeros across all the channels of each filter in the layer. For an RGB image it is a bit different, as we must also impose an ordering on the channels. When predicting pixels, for a given spatial position we predict the values of all the channels (red, then green, then blue) before moving on to the next position.
So, when predicting the red value of a pixel, we use all the previously predicted pixels in all channels, but no value at the current position.
Similarly, when predicting the green value, we use all the previously predicted pixels in all channels plus the red value just predicted at the current position.
When predicting the blue value, we use all the previously predicted pixels in all channels plus the red and green values just predicted at the current position.
To do this, we assign the filters of each layer to the three colour channels in turn: the first filter corresponds to red, the second to green and the third to blue, and the pattern repeats for the remaining filters. For example, when a red filter in any layer operates on an input, it produces an output channel that is classed as red. This is shown in the figure below.
For example, a layer that produces 12 output channels yields 4 feature maps for red, 4 for green and 4 for blue.
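The colour assignment is just the filter index modulo 3. A minimal sketch (the helper name filter_colour is hypothetical) reproduces the 12-channel example:

```python
from collections import Counter


def filter_colour(filter_index: int) -> str:
    """Colour group of a filter: filters cycle red, green, blue, red, ..."""
    return "RGB"[filter_index % 3]


num_filters = 12
colours = [filter_colour(i) for i in range(num_filters)]
print("".join(colours))  # RGBRGBRGBRGB
print(Counter(colours))  # 4 red, 4 green, 4 blue
```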
The masks (specifically type A) are set up so that a red filter cannot use any channel at the current pixel it is predicting. A green filter can use the current pixel value from all the red feature maps, and a blue filter can use the current pixel values of all the red and green feature maps.
The figure below shows a type A mask; the connections drawn are for the central pixel only. The green filters have access to the red feature maps, and the blue filters have access to both the red and green feature maps.
The final layer outputs a probability distribution for every channel, so its depth is the number of values a pixel can take multiplied by 3. The first feature map contains the probability of each pixel's red value being 0, the second the probability of its green value being 0, the third its blue value being 0, and so on.
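With that interleaved layout, feature map level × 3 + colour holds the score for that colour taking that level, so a plain reshape separates the colours. The sketch below assumes a separate softmax over the pixel levels for each colour (the post does not spell out how the final activation is applied), and the function name split_colour_output is hypothetical.

```python
import numpy as np


def split_colour_output(logits: np.ndarray, pixel_levels: int) -> np.ndarray:
    """Turn a (height, width, pixel_levels * 3) output into per-colour
    distributions of shape (height, width, pixel_levels, 3)."""
    height, width, depth = logits.shape
    if depth != pixel_levels * 3:
        raise ValueError("expected pixel_levels * 3 feature maps")
    # C-order reshape: feature map m = level * 3 + colour -> [level, colour]
    per_colour = logits.reshape(height, width, pixel_levels, 3)
    # Numerically stable softmax over the levels, separately for R, G and B
    shifted = per_colour - per_colour.max(axis=2, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=2, keepdims=True)


rng = np.random.default_rng(0)
probs = split_colour_output(rng.normal(size=(4, 4, 16 * 3)), pixel_levels=16)
print(probs.shape)                        # (4, 4, 16, 3)
print(np.allclose(probs.sum(axis=2), 1))  # True: one distribution per colour
```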
We can implement colour masks in the following code:
```python
# Kernel shape is (KH, KW, depth, num_filters)
kernel_shape = tuple(self.conv_layer.kernel.shape)
_, _, num_in_channels, num_filters = kernel_shape
mask = np.zeros(shape=kernel_shape, dtype="float32")
# Initially transpose the mask to (num_filters, depth, KH, KW) to make indexing simpler
mask = np.transpose(mask, axes=(3, 2, 0, 1))
# Unmask every row above the centre row
mask[..., : kernel_shape[0] // 2, :] = 1.0
# In the centre row, unmask the columns left of the centre
mask[..., kernel_shape[0] // 2, : kernel_shape[1] // 2] = 1.0

# Adapted from https://github.com/rampage644/wavenet/blob/master/wavenet/models.py
def bmask(i_out: int, i_in: int) -> np.ndarray:
    # Boolean (num_filters, num_in_channels) array: True where the filter has
    # colour i_out and the input channel has colour i_in (colour = index % 3)
    cout_idx = np.expand_dims(np.arange(num_filters) % 3 == i_out, 1)
    cin_idx = np.expand_dims(np.arange(num_in_channels) % 3 == i_in, 0)
    a1, a2 = np.broadcast_arrays(cout_idx, cin_idx)
    return a1 * a2

# Centre connections: green sees red, blue sees red and green
mask[bmask(1, 0), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
mask[bmask(2, 0), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
mask[bmask(2, 1), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
# Type B: each colour may also use its own colour at the centre
if self.mask_type == "B":
    for i in range(3):
        mask[bmask(i, i), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
# Transpose back to the Keras kernel layout (KH, KW, depth, num_filters)
mask = np.transpose(mask, axes=(2, 3, 1, 0))
self.mask = mask  # stored for call(), as in MaskedConv2D
```
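We keep the initial code the same and only change how values are assigned at the central position. The bmask function creates a boolean mini-mask of shape (number of filters × number of input channels) that selects the filters of one colour and the input channels of another, so those filters can be given access to the central values of those channels.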
For example, assuming we have 6 input channels and 6 output channels, bmask(1, 0) produces:
```
[[False, False, False, False, False, False],
 [True,  False, False, True,  False, False],
 [False, False, False, False, False, False],
 [False, False, False, False, False, False],
 [True,  False, False, True,  False, False],
 [False, False, False, False, False, False]]
```
This enables the green filters to access the central values of the red input channels. The first and fourth input channels are produced by red filters, and the second and fifth filters of the current layer are green filters. Therefore, when used as in the line below, it sets the central value of the first and fourth channels of the second and fifth filters to 1.
```python
mask[bmask(1, 0), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
```
The first three lines below let the green filters use the central values of the red channels, and the blue filters use those of the red and green channels. The final block lets every filter use the central value of its own colour's channels if the mask is of type B.
```python
mask[bmask(1, 0), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
mask[bmask(2, 0), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
mask[bmask(2, 1), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
if self.mask_type == "B":
    for i in range(3):
        mask[bmask(i, i), kernel_shape[0] // 2, kernel_shape[1] // 2] = 1.0
```
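For experimenting outside Keras, the same logic can be pulled out into standalone NumPy functions. This is a sketch, not the post's layer: colour_mask is a hypothetical name, bmask takes the channel counts as explicit arguments instead of reading them from the enclosing scope, and & replaces * on the boolean arrays (equivalent here).

```python
import numpy as np


def bmask(i_out: int, i_in: int, num_out: int, num_in: int) -> np.ndarray:
    """True where the filter has colour i_out and the input channel has colour i_in."""
    out_idx = (np.arange(num_out) % 3 == i_out)[:, None]
    in_idx = (np.arange(num_in) % 3 == i_in)[None, :]
    return out_idx & in_idx  # broadcasts to (num_out, num_in)


def colour_mask(kernel_shape: tuple[int, int, int, int], mask_type: str) -> np.ndarray:
    """Colour PixelCNN mask; kernel_shape is (KH, KW, in_channels, num_filters)."""
    kh, kw, num_in, num_out = kernel_shape
    ch, cw = kh // 2, kw // 2
    # Work in (num_filters, in_channels, KH, KW) layout, like the code above
    mask = np.zeros((num_out, num_in, kh, kw), dtype=np.float32)
    mask[:, :, :ch, :] = 1.0   # rows above the centre
    mask[:, :, ch, :cw] = 1.0  # left of the centre in the centre row
    # Centre connections: green sees red; blue sees red and green
    for i_out, i_in in [(1, 0), (2, 0), (2, 1)]:
        mask[bmask(i_out, i_in, num_out, num_in), ch, cw] = 1.0
    if mask_type == "B":
        # Type B: every colour may also use its own colour at the centre
        for i in range(3):
            mask[bmask(i, i, num_out, num_in), ch, cw] = 1.0
    return np.transpose(mask, axes=(2, 3, 1, 0))  # back to (KH, KW, in, filters)


# Centre connections for 3 input channels and 3 filters;
# rows: R, G, B filters, columns: R, G, B input channels
print(colour_mask((3, 3, 3, 3), "A")[1, 1].T.astype(int))
print(colour_mask((3, 3, 3, 3), "B")[1, 1].T.astype(int))
```

The type A printout is lower-triangular with a zero diagonal (green sees red, blue sees red and green); type B adds the diagonal.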
Mixture Distributions
As the code shows, one downside of the PixelCNN is that it has to output a probability for every possible pixel value. To address this, we can have the network output a mixture distribution instead.
A mixture distribution, as the name suggests, is a mixture of two or more distributions. A categorical distribution gives the probability of choosing each of the component distributions. To sample from the mixture, we first sample from the categorical distribution to choose a component, and then sample from the chosen component in the usual way. This lets us build complex distributions with few parameters.
For example, a mixture of 3 normal distributions needs only 8 parameters: 2 (mean and variance) for each normal distribution, plus 2 for the categorical distribution (its 3 weights must sum to 1, so only 2 are free). This compares with the 255 parameters of a categorical distribution over all 256 values an 8-bit pixel can take.
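The two-stage sampling procedure can be sketched in a few lines of NumPy. The example uses normal components, as in the parameter count that follows; the function name and the particular weights, means and standard deviations are made up for illustration.

```python
import numpy as np


def sample_mixture(weights, means, stds, n: int, rng: np.random.Generator) -> np.ndarray:
    """Two-stage sampling from a mixture of normal distributions."""
    weights = np.asarray(weights, dtype=np.float64)
    # Stage 1: pick a component for every sample from the categorical distribution
    components = rng.choice(len(weights), size=n, p=weights)
    # Stage 2: sample from the chosen normal distribution
    return rng.normal(np.asarray(means)[components], np.asarray(stds)[components])


rng = np.random.default_rng(42)
samples = sample_mixture([0.2, 0.5, 0.3], means=[0.0, 5.0, 10.0], stds=[1.0, 0.5, 2.0], n=10_000, rng=rng)
print(samples.mean())  # close to 0.2 * 0 + 0.5 * 5 + 0.3 * 10 = 5.5
```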
We can build a PixelCNN whose output is a mixture distribution in this way. The model outputs the log-likelihood of the image under the mixture distribution, i.e. how probable the observed input image (training data) is under the distribution the model predicts.
We then use the negative log-likelihood as the loss function, so that minimising the loss maximises the likelihood: the network is optimised to make the observed training images as probable as possible under its predicted distribution.
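The loss itself can be sketched in NumPy for a mixture of continuous normal distributions. This is a simplification for illustration: TensorFlow Probability's PixelCNN uses a discretised mixture of logistic distributions, and mixture_nll is a hypothetical helper. The log-sum-exp step combines the weighted component densities in a numerically stable way.

```python
import numpy as np


def mixture_nll(x, weights, means, stds) -> float:
    """Average negative log-likelihood of data x under a mixture of normals."""
    x = np.asarray(x, dtype=np.float64)[:, None]  # shape (n, 1)
    weights, means, stds = (np.asarray(a, dtype=np.float64) for a in (weights, means, stds))
    # Log of each weighted component density, shape (n, components)
    log_comp = (
        np.log(weights)
        - 0.5 * np.log(2 * np.pi * stds**2)
        - 0.5 * ((x - means) / stds) ** 2
    )
    # Log-sum-exp over the components gives log p(x) for each data point
    m = log_comp.max(axis=1, keepdims=True)
    log_px = m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1))
    return float(-log_px.mean())


print(mixture_nll([0.0], [1.0], [0.0], [1.0]))  # 0.5 * log(2 * pi), about 0.919
data = [4.9, 5.1, 5.0]
# A mixture with a component near the data gives a lower loss
print(mixture_nll(data, [0.5, 0.5], [0.0, 5.0], [1.0, 1.0]) < mixture_nll(data, [0.5, 0.5], [0.0, 10.0], [1.0, 1.0]))  # True
```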
After training, the distribution outputted from the model can be sampled to generate images.
This is simple to implement, as it is built into the TensorFlow Probability library as tfp.distributions.PixelCNN. The code below creates a mixture-distribution PixelCNN with 5 logistic distributions in the mixture. We take the log-probability as the model output and minimise the negative log-likelihood as the loss.
We can then easily generate new images by sampling from the distribution, as shown.
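```python
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers, models, optimizers

N_COMPONENTS = 5
IMAGE_SIZE = 32
EPOCHS = 50
BATCH_SIZE = 128

# PixelCNN whose per-pixel output is a mixture of N_COMPONENTS logistic distributions
dist = tfp.distributions.PixelCNN(
    image_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=N_COMPONENTS,
    dropout_p=0.3,
)

# The model's output is the log-likelihood of each input image
image_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))
log_prob = dist.log_prob(image_input)
pixelcnn = models.Model(inputs=image_input, outputs=log_prob)
# Minimise the negative log-likelihood
pixelcnn.add_loss(-tf.reduce_mean(log_prob))
pixelcnn.compile(optimizer=optimizers.Adam(0.001))

# input_data: images of shape (N, 32, 32, 1) with pixel values in 0-255
# (the distribution's default range, low=0 and high=255)
pixelcnn.fit(input_data, batch_size=BATCH_SIZE, epochs=EPOCHS)

# Sample 10 images from the distribution that the model learned
dist.sample(10).numpy()
```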
Conclusions
This blog post gave an overview of the PixelCNN architecture and implemented it. We found that it was able to generate new image instances starting from an array of zeros. We also outlined one downside of the architecture, the need to output a probability for every possible pixel value, and explained how mixture distributions can overcome it.
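The full code can be found on GitHub.
References
[1] David Foster, Generative Deep Learning
[2] Aäron van den Oord, Nal Kalchbrenner, Koray Kavukcuoglu, Pixel Recurrent Neural Networks, https://arxiv.org/abs/1601.06759v3
[3] https://sergeiturukin.com/2017/02/22/pixelcnn.html
[4] https://paperswithcode.com/method/pixelcnn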