# Tutorial 11 (JAX): Normalizing Flows for image modeling

**Note:** This notebook is written in JAX+Flax. It is a 1-to-1 translation of the original notebook written in PyTorch+PyTorch Lightning with almost identical results. For an introduction to JAX, check out our Tutorial 2 (JAX): Introduction to JAX+Flax. Further, throughout the notebook, we comment on major differences to the PyTorch version and provide explanations for the major parts of the JAX code.

**Speed comparison**: We note the training times for all models in the PyTorch and the JAX implementation below (PyTorch v1.11, JAX v0.3.13). The models were trained on the same hardware (NVIDIA RTX3090, 24 core CPU) and we slightly adjusted the tutorials to use the exact same training settings (200 epochs, data loading parameters, evaluation schedule, etc.). Overall, the JAX implementation is *2.0-2.5x faster* than PyTorch! Different network architectures may have different speedups.

| Models | PyTorch | JAX |
|---|---|---|
| MNIST Flow - Simple | 2hrs 37min 29sec | 1hrs 17min 59sec |
| MNIST Flow - VarDeq | 3hrs 25min 10sec | 1hrs 36min 56sec |
| MNIST Flow - Multiscale | 2hrs 17min 10sec | 57min 57sec |

In this tutorial, we will take a closer look at complex, deep normalizing flows. The most popular current application of deep normalizing flows is to model datasets of images. As for other generative models, images are a good domain to start working on because (1) CNNs are widely studied and strong models exist, (2) images are high-dimensional and complex, and (3) images are discrete integers. In this tutorial, we will review current advances in normalizing flows for image modeling, and get hands-on experience coding normalizing flows. Note that normalizing flows are commonly parameter-heavy and therefore computationally expensive. We will use relatively simple and shallow flows to save computational cost and allow you to run the notebook on CPU, but keep in mind that a simple way to improve the scores of the flows we study here is to make them deeper. The first cell imports our usual libraries.


```
## Standard libraries
import os
import math
import time
import json
import numpy as np
from typing import Sequence
## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
## Progress bar
from tqdm.notebook import tqdm
## To run JAX on TPU in Google Colab, uncomment the two lines below
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
## JAX
import jax
import jax.numpy as jnp
from jax import random
## Flax (NN in JAX)
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax
## PyTorch Data Loading
import torch
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.datasets import MNIST
# Path to the folder where the datasets are/should be downloaded (e.g. MNIST)
DATASET_PATH = "../../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../../saved_models/tutorial11_jax"
# Seeding for random operations
main_rng = random.PRNGKey(42)
print("Device:", jax.devices()[0])
```

```
/tmp/ipykernel_3604624/1911608048.py:13: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
set_matplotlib_formats('svg', 'pdf') # For export
```

```
Device: gpu:0
```

As in previous tutorials, we provide a few pretrained models, which we download below to the checkpoint path specified above.


```
"""
import urllib.request
from urllib.error import HTTPError
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/JAX/tutorial11/"
# Files to download
pretrained_files = ["MNISTFlow_simple.ckpt", "MNISTFlow_vardeq.ckpt", "MNISTFlow_multiscale.ckpt",
"MNISTFlow_simple_results.json", "MNISTFlow_vardeq_results.json", "MNISTFlow_multiscale_results.json"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
file_path = os.path.join(CHECKPOINT_PATH, file_name)
if not os.path.isfile(file_path):
file_url = base_url + file_name
print(f"Downloading {file_url}...")
try:
urllib.request.urlretrieve(file_url, file_path)
except HTTPError as e:
print("Something went wrong. Please contact the author with the full output including the following error:\n", e)
"""
```

```
[2]:
```

```
'\nimport urllib.request\nfrom urllib.error import HTTPError\n# Github URL where saved models are stored for this tutorial\nbase_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/JAX/tutorial11/"\n# Files to download\npretrained_files = ["MNISTFlow_simple.ckpt", "MNISTFlow_vardeq.ckpt", "MNISTFlow_multiscale.ckpt",\n "MNISTFlow_simple_results.json", "MNISTFlow_vardeq_results.json", "MNISTFlow_multiscale_results.json"]\n# Create checkpoint path if it doesn\'t exist yet\nos.makedirs(CHECKPOINT_PATH, exist_ok=True)\n\n# For each file, check whether it already exists. If not, try downloading it.\nfor file_name in pretrained_files:\n file_path = os.path.join(CHECKPOINT_PATH, file_name)\n if not os.path.isfile(file_path):\n file_url = base_url + file_name\n print(f"Downloading {file_url}...")\n try:\n urllib.request.urlretrieve(file_url, file_path)\n except HTTPError as e:\n print("Something went wrong. Please contact the author with the full output including the following error:\n", e)\n'
```

We will use the MNIST dataset in this notebook. Despite its simplicity, MNIST constitutes a challenge for small generative models, as it requires a global understanding of an image. At the same time, we can easily judge whether generated images come from the same distribution as the dataset (i.e. represent real digits), or not.

To deal better with the discrete nature of the images, we keep them as integers in the range 0-255, instead of rescaling them to floats in the range 0-1.


```
# Transformations applied on each image => bring them into a numpy array
# Note that we keep them in the range 0-255 (integers)
def image_to_numpy(img):
    img = np.array(img, dtype=np.int32)
    img = img[...,None] # Make image [28, 28, 1]
    return img

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=image_to_numpy, download=True)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000],
                                                   generator=torch.Generator().manual_seed(42))
# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=image_to_numpy, download=True)
# We define a set of data loaders that we can use for various purposes
# Data loader for loading examples throughout the notebook
train_exmp_loader = data.DataLoader(train_set, batch_size=256, shuffle=False, drop_last=False, collate_fn=numpy_collate)
# Actual data loaders for training, validation, and testing
train_data_loader = data.DataLoader(train_set,
                                    batch_size=128,
                                    shuffle=True,
                                    drop_last=True,
                                    collate_fn=numpy_collate,
                                    num_workers=8,
                                    persistent_workers=True)
val_loader = data.DataLoader(val_set, batch_size=64, shuffle=False, drop_last=False, num_workers=4, collate_fn=numpy_collate)
test_loader = data.DataLoader(test_set, batch_size=64, shuffle=False, drop_last=False, num_workers=4, collate_fn=numpy_collate)
```

In addition, we define below a function to simplify the visualization of images/samples. Some training examples of the MNIST dataset are shown below.


```
def show_imgs(imgs, title=None, row_size=4):
    # Form a grid of pictures (at most row_size images per row)
    imgs = np.copy(jax.device_get(imgs))
    num_imgs = imgs.shape[0]
    is_int = (imgs.dtype==np.int32)
    nrow = min(num_imgs, row_size)
    ncol = int(math.ceil(num_imgs/nrow))
    # torchvision expects [B, C, H, W], while our images are [B, H, W, C]
    imgs_torch = torch.from_numpy(imgs).permute(0, 3, 1, 2)
    imgs = torchvision.utils.make_grid(imgs_torch, nrow=nrow, pad_value=128 if is_int else 0.5)
    np_imgs = imgs.cpu().numpy()
    # Plot the grid
    plt.figure(figsize=(1.5*nrow, 1.5*ncol))
    plt.imshow(np.transpose(np_imgs, (1,2,0)), interpolation='nearest')
    plt.axis('off')
    if title is not None:
        plt.title(title)
    plt.show()
    plt.close()

show_imgs(np.stack([train_set[i][0] for i in range(8)], axis=0))
```

## Normalizing Flows as generative models

In the previous lectures, we have seen Energy-based models, Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs) as examples of generative models. However, none of them explicitly learn the probability density function \(p(x)\) of the real input data. While VAEs model a lower bound, energy-based models only implicitly learn the probability density. GANs, on the other hand, provide us with a sampling mechanism for generating new data, without offering a likelihood estimate. The generative model we will look at here, called Normalizing Flows, actually models the true data distribution \(p(x)\) and provides us with an exact likelihood estimate. Below, we can visually compare VAEs, GANs and Flows (figure credit - Lilian Weng):

The major difference compared to VAEs is that flows use *invertible* functions \(f\) to map the input data \(x\) to a latent representation \(z\). To realize this, \(z\) must be of the same shape as \(x\). This is in contrast to VAEs where \(z\) is usually much lower dimensional than the original input data. However, an invertible mapping also means that for every data point \(x\), we have a corresponding latent representation \(z\) which allows us to perform lossless reconstruction (\(z\) to \(x\)). In the visualization above, this means that \(x=x'\) for flows, no matter what invertible function \(f\) and input \(x\) we choose.

Nonetheless, how do normalizing flows model a probability density with an invertible function? The answer to this question is the rule for change of variables. Specifically, given a prior density \(p_z(z)\) (e.g. Gaussian) and an invertible function \(f\), we can determine \(p_x(x)\) as follows:

\[
p_x(x) = p_z(f(x)) \left|\frac{df(x)}{dx}\right|
\]

Hence, in order to determine the probability of \(x\), we only need to determine its probability in latent space, and compute the derivative of \(f\). Note that this is for a univariate distribution, and \(f\) is required to be invertible and smooth. For the multivariate case, the derivative becomes a Jacobian, of which we need to take the determinant. As we usually use the log-likelihood as objective, we write the multivariate term with logarithms below:

\[
\log p_x(\mathbf{x}) = \log p_z(f(\mathbf{x})) + \log{} \left|\det \frac{df(\mathbf{x})}{d\mathbf{x}}\right|
\]
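As a small sanity check of the change-of-variables formula, the sketch below uses a hypothetical one-dimensional example: if the data followed a Gaussian with mean `mu` and standard deviation `sigma` (both assumed values, not taken from the notebook), the affine flow \(f(x) = (x-\mu)/\sigma\) maps it to a standard Gaussian, and the formula should recover the exact Gaussian log-density. The derivative is taken with `jax.grad` instead of being written out by hand.

```
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical data distribution parameters (assumed for illustration)
mu, sigma = 2.0, 0.5

def f(x):
    # Invertible affine flow from data space x to latent space z
    return (x - mu) / sigma

def log_px_via_flow(x):
    # log p_x(x) = log p_z(f(x)) + log |df(x)/dx|, with a standard Gaussian prior p_z
    log_pz = norm.logpdf(f(x))
    dfdx = jax.vmap(jax.grad(f))(x)  # elementwise derivative of the scalar map f
    return log_pz + jnp.log(jnp.abs(dfdx))

x = jnp.linspace(-1.0, 5.0, 7)
flow_ll = log_px_via_flow(x)
# Reference: the Gaussian log-density evaluated directly
exact_ll = norm.logpdf(x, loc=mu, scale=sigma)
print(jnp.max(jnp.abs(flow_ll - exact_ll)))  # only float32 rounding remains
```

In a real flow, \(f\) is learned and multivariate, so the derivative becomes the log-determinant term above, but the bookkeeping is the same.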

Although we now know how a normalizing flow obtains its likelihood, it might not be clear what a normalizing flow does intuitively. For this, we should look at the flow from the inverse perspective, starting with the prior probability density \(p_z(z)\). If we apply an invertible function on it, we effectively “transform” its probability density. For instance, if \(f^{-1}(z)=z+1\), we shift the density by one while still remaining a valid probability distribution, and being invertible. We can also apply more complex transformations, like scaling: \(f^{-1}(z)=2z+1\), but here there is an important difference. When you scale, you also change the volume of the probability density, as for example on uniform distributions (figure credit - Eric Jang):

You can see that the height of \(p(y)\) should be lower than \(p(x)\) after scaling. This change in volume represents \(\left|\frac{df(x)}{dx}\right|\) in our equation above, and ensures that even after scaling, we still have a valid probability distribution. We can go on with making our function \(f\) more complex. However, the more complex \(f\) becomes, the harder it will be to find the inverse \(f^{-1}\) of it, and to calculate the log-determinant of the Jacobian \(\log{} \left|\det \frac{df(\mathbf{x})}{d\mathbf{x}}\right|\). An easier trick is to stack multiple invertible functions \(f_{1,...,K}\) after each other, as all together, they still represent a single, invertible function. Using multiple, learnable invertible functions, a normalizing flow attempts to transform \(p_z(z)\) slowly into a more complex distribution which should finally be \(p_x(x)\). We visualize the idea below (figure credit - Lilian Weng):
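Stacking works because, by the chain rule, the log-determinant of a composition is the sum of the per-layer log-determinants, each evaluated at that layer's own input. The sketch below checks this numerically for two hypothetical invertible maps on \(\mathbb{R}^3\) (an affine map with an assumed matrix `A`, followed by an elementwise `sinh`), comparing the accumulated LDJ against the log-determinant of the full Jacobian computed with `jax.jacfwd`.

```
import jax
import jax.numpy as jnp

# Hypothetical invertible affine map (det(A) != 0) followed by an elementwise sinh
A = jnp.array([[2.0, 0.5, 0.0],
               [0.0, 1.0, 0.3],
               [0.1, 0.0, 1.5]])
b = jnp.array([0.1, -0.2, 0.3])

def f1(x):
    return A @ x + b

def f2(x):
    # sinh is a smooth bijection from R to R with derivative cosh
    return jnp.sinh(x)

def ldj_f1(x):
    # The Jacobian of f1 is A everywhere
    return jnp.linalg.slogdet(A)[1]

def ldj_f2(x):
    # The Jacobian of f2 is diagonal with entries cosh(x_i)
    return jnp.sum(jnp.log(jnp.cosh(x)))

def flow(x):
    # Apply the layers one after another and accumulate the LDJ, similar to the flows in this notebook
    ldj = 0.0
    for layer, layer_ldj in [(f1, ldj_f1), (f2, ldj_f2)]:
        ldj = ldj + layer_ldj(x)
        x = layer(x)
    return x, ldj

x = jnp.array([0.3, -1.2, 0.7])
z, ldj_sum = flow(x)
# Reference: log|det| of the Jacobian of the whole composition
ldj_full = jnp.linalg.slogdet(jax.jacfwd(lambda v: flow(v)[0])(x))[1]
print(ldj_sum, ldj_full)
```

This is why each flow layer in the implementation below only has to report its own LDJ, which is then added to a running total.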

Starting from \(z_0\), which follows the prior Gaussian distribution, we sequentially apply the invertible functions \(f_1,f_2,...,f_K\), until \(z_K\) represents \(x\). Note that in the figure above, the functions \(f\) represent the inverse of the function \(f\) we had above (here: \(f:Z\to X\), above: \(f:X\to Z\)). This is just a different notation and has no impact on the actual flow design because all \(f\) need to be invertible anyway. When we estimate the log likelihood of a data point \(x\) as in the equations above, we run the flows in the opposite direction to the one visualized above. Multiple flow layers that are parameterized by neural networks have been proposed, such as the planar and radial flow. However, we will focus here on flows that are commonly used in image modeling, and will discuss them in the rest of the notebook along with the details of how to train a normalizing flow.

## Normalizing Flows on images

To become familiar with normalizing flows, especially for the application of image modeling, it is best to discuss the different elements in a flow along with the implementation. As a general concept, we want to build a normalizing flow that maps an input image (here MNIST) to an equally sized latent space:

As a first step, we will implement a template of a normalizing flow. During training and validation, a normalizing flow performs density estimation in the forward direction. For this, we apply a series of flow transformations on the input \(x\) and estimate the probability of the input by determining the probability of the transformed point \(z\) given a prior, and the change of volume caused by the transformations. During inference, we can do both density estimation and sampling new points by inverting the flow transformations. Therefore, we define a function `_get_likelihood` which performs density estimation, and `sample` to generate new examples.

The standard metric used in generative models, and in particular normalizing flows, is bits per dimension (bpd). Bpd is motivated from an information theory perspective and describes how many bits we would need to encode a particular example in our modeled distribution. The fewer bits we need, the more likely the example is under our distribution. When we evaluate the bits per dimension on our test dataset, we can judge whether our model generalizes to new samples of the dataset and did not simply memorize the training dataset. In order to calculate the bits per dimension score, we can rely on the negative log-likelihood and change the log base (bits are base-2 units, while the NLL is usually computed with the natural logarithm):

\[
\text{bpd} = \text{nll} \cdot \log_2\left(\exp(1)\right) \cdot \left(\prod d_i\right)^{-1}
\]

where \(d_1,...,d_K\) are the dimensions of the input. For images, this would be the height, width and channel number. We divide the negative log-likelihood by the product of these dimensions to obtain a metric which we can compare across different image resolutions. In the original image space, MNIST examples have a bits per dimension score of 8 (we need 8 bits to encode each pixel as there are 256 possible values).
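To make the conversion concrete, the sketch below applies the same formula as `_get_likelihood` through a small helper, `nll_to_bpd` (a hypothetical name used only here). A model that is uniform over the 256 pixel values assigns each pixel probability \(1/256\), which should give exactly the 8 bpd mentioned above.

```
import numpy as np

def nll_to_bpd(nll, img_shape):
    # nll: negative log-likelihood in nats (natural log), summed over all dimensions
    # Multiply by log2(e) to convert nats to bits, then divide by the number of dimensions
    return nll * np.log2(np.exp(1)) / np.prod(img_shape)

img_shape = (28, 28, 1)  # MNIST: height, width, channels
# Uniform model over 256 values: NLL = 784 * ln(256) nats per image
nll_uniform = np.prod(img_shape) * np.log(256)
print(nll_to_bpd(nll_uniform, img_shape))  # approximately 8.0

# Conversely, 1 bpd on MNIST corresponds to an NLL of 784 * ln(2) nats per image
print(np.prod(img_shape) * np.log(2))
```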


```
class ImageFlow(nn.Module):
    flows : Sequence[nn.Module] # A list of flows (each a nn.Module) that should be applied on the images.
    import_samples : int = 8 # Number of importance samples to use during testing (see explanation below).

    def __call__(self, x, rng, testing=False):
        if not testing:
            bpd, rng = self._get_likelihood(x, rng)
        else:
            # Perform importance sampling during testing => estimate likelihood M times for each image
            img_ll, rng = self._get_likelihood(jnp.repeat(x, self.import_samples, axis=0),
                                               rng,
                                               return_ll=True)
            img_ll = img_ll.reshape(-1, self.import_samples)
            # To average the probabilities, we need to go from log-space to exp, and back to log.
            # Logsumexp provides us a stable implementation for this
            img_ll = jax.nn.logsumexp(img_ll, axis=-1) - np.log(self.import_samples)
            # Calculate final bpd
            bpd = -img_ll * np.log2(np.exp(1)) / np.prod(x.shape[1:])
            bpd = bpd.mean()
        return bpd, rng

    def encode(self, imgs, rng):
        # Given a batch of images, return the latent representation z and ldj of the transformations
        z, ldj = imgs, jnp.zeros(imgs.shape[0])
        for flow in self.flows:
            z, ldj, rng = flow(z, ldj, rng, reverse=False)
        return z, ldj, rng

    def _get_likelihood(self, imgs, rng, return_ll=False):
        """
        Given a batch of images, return the likelihood of those.
        If return_ll is True, this function returns the log likelihood of the input.
        Otherwise, the output metric is bits per dimension (scaled negative log likelihood)
        """
        z, ldj, rng = self.encode(imgs, rng)
        log_pz = jax.scipy.stats.norm.logpdf(z).sum(axis=(1,2,3))
        log_px = ldj + log_pz
        nll = -log_px
        # Calculating bits per dimension
        bpd = nll * np.log2(np.exp(1)) / np.prod(imgs.shape[1:])
        return (bpd.mean() if not return_ll else log_px), rng

    def sample(self, img_shape, rng, z_init=None):
        """
        Sample a batch of images from the flow.
        """
        # Sample latent representation from prior
        if z_init is None:
            rng, normal_rng = random.split(rng)
            z = random.normal(normal_rng, shape=img_shape)
        else:
            z = z_init
        # Transform z to x by inverting the flows
        ldj = jnp.zeros(img_shape[0])
        for flow in reversed(self.flows):
            z, ldj, rng = flow(z, ldj, rng, reverse=True)
        return z, rng
```

The test mode (`testing=True` in `ImageFlow.__call__`) differs from training and validation in that it makes use of importance sampling. We will discuss the motivation and details behind this after understanding how flows model discrete images in continuous space.

### Dequantization

Normalizing flows rely on the rule of change of variables, which is naturally defined in continuous space. Applying flows directly on discrete data leads to undesired density models where arbitrarily high likelihoods are placed on a few, particular values. See the illustration below:

The black points represent the discrete points, and the green volume the density modeled by a normalizing flow in continuous space. The flow would continue to increase the likelihood for \(x=0,1,2,3\) while having no volume on any other point. Remember that in continuous space, we have the constraint that the overall volume of the probability density must be 1 (\(\int p(x)dx=1\)). Otherwise, we don’t model a probability distribution anymore. However, the discrete points \(x=0,1,2,3\) represent delta peaks with no width in continuous space. This is why the flow can place an infinitely high likelihood on these few points while still representing a distribution in continuous space. Nonetheless, the learned density does not tell us anything about the distribution among the discrete points, as in discrete space, the likelihoods of those four points would have to sum to 1, not to infinity.

To prevent such degenerate solutions, a common solution is to add a small amount of noise to each discrete value, which is also referred to as dequantization. Considering \(x\) as an integer (as is the case for images), the dequantized representation \(v\) can be formulated as \(v=x+u\) where \(u\in[0,1)^D\). Thus, the discrete value \(1\) is modeled by a distribution over the interval \([1.0, 2.0)\), the value \(2\) by a volume over \([2.0, 3.0)\), etc. Our objective of modeling \(p(x)\) becomes:

\[
p(x) = \int p(x+u)du = \int \frac{q(u|x)}{q(u|x)}p(x+u)du = \mathbb{E}_{u\sim q(u|x)}\left[\frac{p(x+u)}{q(u|x)} \right]
\]

with \(q(u|x)\) being the noise distribution. For now, we assume it to be uniform, which can also be written as \(p(x)=\mathbb{E}_{u\sim U(0,1)^D}\left[p(x+u) \right]\).
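With uniform noise, \(p(x)=\mathbb{E}_{u\sim U(0,1)}\left[p(x+u)\right]\) is just the probability mass of the interval \([x, x+1)\). The one-dimensional sketch below estimates it by Monte Carlo; as an assumption for illustration, the continuous density is a standard logistic (not a trained flow), because its interval mass has the closed form \(\sigma(x+1)-\sigma(x)\) that we can compare against.

```
import jax
import jax.numpy as jnp
from jax import random

def logistic_pdf(v):
    # Standard logistic density sigma(v) * (1 - sigma(v)), an assumed stand-in for the flow's p(v)
    s = jax.nn.sigmoid(v)
    return s * (1.0 - s)

def discrete_prob_mc(x, rng, num_samples=10_000):
    # Uniform dequantization: P(x) = E_{u ~ U(0,1)}[p(x + u)], estimated with Monte Carlo samples of u
    u = random.uniform(rng, (num_samples,))
    return logistic_pdf(x + u).mean()

xs = jnp.arange(-3, 4).astype(jnp.float32)  # a few integer values
rngs = random.split(random.PRNGKey(0), xs.shape[0])
mc_probs = jax.vmap(discrete_prob_mc)(xs, rngs)
# Exact probability mass of the interval [x, x+1) under the logistic density
exact_probs = jax.nn.sigmoid(xs + 1.0) - jax.nn.sigmoid(xs)
print(mc_probs)
print(exact_probs)
```

Because the intervals \([x, x+1)\) tile the real line, these discrete probabilities sum to one over all integers, which is exactly the property the delta peaks above were missing.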

In the following, we will implement Dequantization as a flow transformation itself. After adding noise to the discrete values, we additionally transform the volume into a Gaussian-like shape. This is done by scaling \(x+u\) between \(0\) and \(1\), and applying the inverse of the sigmoid function \(\sigma^{-1}(z) = \log z - \log (1-z)\). If we did not do this, we would face two problems:

1. The input is scaled between 0 and 256 while the prior distribution is a Gaussian with mean \(0\) and standard deviation \(1\). In the first iterations after initializing the parameters of the flow, we would have extremely low likelihoods for large values like \(256\). This would cause the training to diverge instantaneously.
2. As the output distribution is a Gaussian, it is beneficial for the flow to have a similarly shaped input distribution. This will reduce the modeling complexity that is required by the flow.

Overall, we can implement dequantization as follows:


```
class Dequantization(nn.Module):
    alpha : float = 1e-5 # Small constant that is used to scale the original input for numerical stability.
    quants : int = 256 # Number of possible discrete values (usually 256 for 8-bit image)

    def __call__(self, z, ldj, rng, reverse=False):
        if not reverse:
            z, ldj, rng = self.dequant(z, ldj, rng)
            z, ldj = self.sigmoid(z, ldj, reverse=True)
        else:
            z, ldj = self.sigmoid(z, ldj, reverse=False)
            z = z * self.quants
            ldj += np.log(self.quants) * np.prod(z.shape[1:])
            z = jnp.floor(z)
            z = jax.lax.clamp(min=0., x=z, max=self.quants-1.).astype(jnp.int32)
        return z, ldj, rng

    def sigmoid(self, z, ldj, reverse=False):
        # Applies an invertible sigmoid transformation
        if not reverse:
            ldj += (-z-2*jax.nn.softplus(-z)).sum(axis=(1,2,3))
            z = nn.sigmoid(z)
            # Reverse the boundary scaling of the inverse direction, so both directions are exact inverses
            ldj -= np.log(1 - self.alpha) * np.prod(z.shape[1:])
            z = (z - 0.5 * self.alpha) / (1 - self.alpha)
        else:
            z = z * (1 - self.alpha) + 0.5 * self.alpha # Scale to prevent boundaries 0 and 1
            ldj += np.log(1 - self.alpha) * np.prod(z.shape[1:])
            ldj += (-jnp.log(z) - jnp.log(1-z)).sum(axis=(1,2,3))
            z = jnp.log(z) - jnp.log(1-z)
        return z, ldj

    def dequant(self, z, ldj, rng):
        # Transform discrete values to continuous volumes
        z = z.astype(jnp.float32)
        rng, uniform_rng = random.split(rng)
        z = z + random.uniform(uniform_rng, z.shape)
        z = z / self.quants
        ldj -= np.log(self.quants) * np.prod(z.shape[1:])
        return z, ldj, rng
```
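The log-determinant term in the forward direction of `sigmoid` relies on the identity \(\log \sigma'(z) = \log\left(\sigma(z)(1-\sigma(z))\right) = -z - 2\,\text{softplus}(-z)\), which avoids computing \(1-\sigma(z)\) explicitly, while the reverse direction uses \(-\log y - \log(1-y)\) for \(y=\sigma(z)\). A standalone check of both expressions against automatic differentiation (independent of the module above) could look like this:

```
import jax
import jax.numpy as jnp

z = jnp.linspace(-6.0, 6.0, 121)
# Closed form used in the forward (sigmoid) direction
ldj_formula = -z - 2 * jax.nn.softplus(-z)
# Reference: log of the sigmoid's derivative, computed with autodiff
ldj_autodiff = jnp.log(jax.vmap(jax.grad(jax.nn.sigmoid))(z))
# Reverse (inverse sigmoid) direction: its LDJ is the negative of the forward LDJ
y = jax.nn.sigmoid(z)
ldj_logit = -jnp.log(y) - jnp.log(1 - y)
print(jnp.max(jnp.abs(ldj_formula - ldj_autodiff)))
```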

A good check of whether a flow is correctly implemented is to verify that it is invertible. Hence, we will dequantize a training image (here, the first one in the training set), and then quantize it again. We would expect to get the exact same image back:


```
## Testing invertibility of dequantization layer
orig_img = train_set[0][0][None] # Example image
ldj = jnp.zeros(1,)
dequant_module = Dequantization()
dequant_rng = random.PRNGKey(5)
deq_img, ldj, dequant_rng = dequant_module(orig_img, ldj, dequant_rng, reverse=False)
reconst_img, ldj, dequant_rng = dequant_module(deq_img, ldj, dequant_rng, reverse=True)
d1, d2 = jnp.where(orig_img.squeeze() != reconst_img.squeeze())
if len(d1) != 0:
    print("Dequantization was not invertible.")
    for i in range(d1.shape[0]):
        print("Original value:", orig_img[0,d1[i],d2[i],0].item())
        print("Reconstructed value:", reconst_img[0,d1[i],d2[i],0].item())
else:
    print("Successfully inverted dequantization")
# Layer is not strictly invertible due to float precision constraints
# assert (orig_img == reconst_img).all().item()
```

```
Dequantization was not invertible.
Original value: 0
Reconstructed value: 1
```

In contrast to our expectation, the test fails. However, this is no reason to doubt our implementation here, as only one single value is not equal to the original. This is caused by numerical inaccuracies in the inverse sigmoid. While the input space of the inverse sigmoid is between 0 and 1, its output space is between \(-\infty\) and \(\infty\). And as we use 32 bits to represent the numbers (in addition to applying logs over and over again), such inaccuracies can occur and should not be worrisome. Nevertheless, it is good to be aware of them. They can be reduced by using double precision (float64), which in JAX requires enabling 64-bit mode via `jax.config.update("jax_enable_x64", True)`.
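To get a feeling for the size of these errors, the sketch below pushes random values in \([0,1)\) through the boundary scaling and inverse sigmoid used in `Dequantization`, then back through a sigmoid and the inverse scaling, and measures the maximum round-trip error in float32 and float64. It uses NumPy instead of JAX so that JAX's global 64-bit flag does not need to be changed; the sample count and seed are arbitrary choices.

```
import numpy as np

def roundtrip_error(dtype, num=100_000, alpha=1e-5, seed=0):
    rng = np.random.default_rng(seed)
    v = rng.uniform(0.0, 1.0, size=num).astype(dtype)  # dequantized values scaled to [0, 1)
    a = dtype(alpha)
    # Forward: move away from the boundaries 0 and 1, then apply the inverse sigmoid (logit)
    z = v * (1 - a) + 0.5 * a
    z = np.log(z) - np.log(1 - z)
    # Reverse: sigmoid, then undo the boundary scaling
    back = 1 / (1 + np.exp(-z))
    back = (back - 0.5 * a) / (1 - a)
    return np.max(np.abs(back - v))

err32 = roundtrip_error(np.float32)
err64 = roundtrip_error(np.float64)
# Multiplying by 256 expresses the error in units of pixel values before flooring
print(err32 * 256, err64 * 256)
```

The float64 error should be many orders of magnitude smaller. In float32, a dequantized value that lies very close to an integer boundary can still be pushed across it, which then flips the floored pixel value.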

Finally, we can take our dequantization and actually visualize the distribution it transforms the discrete values into:


```
def visualize_dequantization(quants, prior=None):
    """
    Function for visualizing the dequantization values of discrete values in continuous space
    """
    # Prior over discrete values. If not given, a uniform is assumed
    if prior is None:
        prior = np.ones(quants, dtype=np.float32) / quants
    prior = prior / prior.sum() * quants # In the following, we assume 1 for each value means uniform distribution
    inp = jnp.arange(-4, 4, 0.01).reshape(-1, 1, 1, 1) # Possible continuous values we want to consider
    ldj = jnp.zeros(inp.shape[0])
    dequant_module = Dequantization(quants=quants)
    # Invert dequantization on continuous values to find corresponding discrete value
    out, ldj, _ = dequant_module(inp, ldj, rng=None, reverse=True)
    inp, out, prob = inp.squeeze(), out.squeeze(), jnp.exp(ldj)
    prob = prob * prior[out] # Probability scaled by categorical prior
    # Plot volumes and continuous distribution
    sns.set_style("white")
    fig = plt.figure(figsize=(6,3))
    x_ticks = []
    for v in np.unique(out):
        indices = np.where(out==v)
        color = to_rgb(f"C{v}")
        plt.fill_between(inp[indices], prob[indices], np.zeros(indices[0].shape[0]), color=color+(0.5,), label=str(v))
        plt.plot([inp[indices[0][0]]]*2, [0, prob[indices[0][0]]], color=color)
        plt.plot([inp[indices[0][-1]]]*2, [0, prob[indices[0][-1]]], color=color)
        x_ticks.append(inp[indices[0][0]])
    x_ticks.append(inp.max())
    plt.xticks(x_ticks, [f"{x:.1f}" for x in x_ticks])
    plt.plot(inp,prob, color=(0.0,0.0,0.0))
    # Set final plot properties
    plt.ylim(0, prob.max()*1.1)
    plt.xlim(inp.min(), inp.max())
    plt.xlabel("z")
    plt.ylabel("Probability")
    plt.title(f"Dequantization distribution for {quants} discrete values")
    plt.legend()
    plt.show()
    plt.close()
visualize_dequantization(quants=8)
```

The visualized distribution shows the sub-volumes that are assigned to the different discrete values. The value \(0\) has its volume in \((-\infty, -1.9)\), the value \(1\) is represented by the interval \([-1.9, -1.1)\), etc. The volume for each discrete value has the same probability mass. That’s why the volumes close to the center (e.g. 3 and 4) cover a smaller interval on the z-axis than the others (\(z\) is being used to denote the output of the whole dequantization flow).
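These interval boundaries can also be computed directly. The boundary between the discrete values \(k-1\) and \(k\) lies at \(k/8\) after adding noise and dividing by the number of values, so in \(z\)-space it sits at the inverse sigmoid of that point (including the same \(\alpha\)-scaling as in `Dequantization`). A short NumPy sketch:

```
import numpy as np

quants, alpha = 8, 1e-5  # same settings as visualize_dequantization(quants=8) with the default alpha
# Boundaries between consecutive discrete values k-1 and k lie at k / quants in [0, 1)
v = np.arange(1, quants) / quants
v = v * (1 - alpha) + 0.5 * alpha          # boundary scaling from Dequantization.sigmoid(reverse=True)
boundaries = np.log(v) - np.log(1 - v)     # inverse sigmoid maps the boundaries into z-space
widths = np.diff(boundaries)               # z-widths of the finite intervals for the values 1..6
print(np.round(boundaries, 2))
print(np.round(widths, 2))
```

The first two boundaries come out at roughly \(\log(1/7)\approx-1.95\) and \(\log(1/3)\approx-1.10\), matching the tick labels in the plot, and the narrowest intervals belong to the central values 3 and 4.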

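Effectively, the subsequent normalizing flow models discrete images by the following objective:

\[
\log p(x) = \log \mathbb{E}_{u\sim q(u|x)}\left[\frac{p(x+u)}{q(u|x)} \right] \geq \mathbb{E}_{u}\left[\log \frac{p(x+u)}{q(u|x)} \right]
\]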

Although normalizing flows are exact in likelihood, we now only have a lower bound. Specifically, this is an instance of Jensen's inequality, because we need to move the log into the expectation so that we can use Monte Carlo estimates. In general, the gap of this bound is considerably smaller than that of the ELBO in variational autoencoders. Moreover, we can tighten the bound ourselves by estimating the expectation not with one, but with \(M\) samples. In other words, we can apply importance sampling, which leads to the following inequality:

\[
\log p(x) \geq \mathbb{E}_{u_1,\ldots,u_M}\left[\log \frac{1}{M} \sum_{m=1}^{M} \frac{p(x+u_m)}{q(u_m|x)} \right] \geq \mathbb{E}_{u}\left[\log \frac{p(x+u)}{q(u|x)} \right]
\]

The importance sampling estimate \(\frac{1}{M} \sum_{m=1}^{M} \frac{p(x+u_m)}{q(u_m|x)}\) becomes \(\mathbb{E}_{u\sim q(u|x)}\left[\frac{p(x+u)}{q(u|x)} \right]\) if \(M\to \infty\), so the more samples we use, the tighter the bound is. During testing, we can make use of this property; in `ImageFlow`, it is implemented in the `testing=True` branch of `__call__`. In theory, we could also use this tighter bound during training. However, related work has shown that this does not necessarily lead to an improvement given the additional computational cost, and that it is more efficient to stick with a single estimate [5].
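The sketch below mimics the `testing=True` branch of `ImageFlow` on a toy one-dimensional problem. As assumptions for illustration, the continuous density is a standard logistic instead of a trained flow, and \(q(u|x)\) is uniform, so the log importance weights are simply \(\log p(x+u_m)\). Averaging in probability space is done stably with `logsumexp`, as in `ImageFlow`, and by Jensen's inequality the resulting estimate can never fall below the average of the single-sample log-likelihoods computed from the same samples.

```
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp

def logistic_logpdf(v):
    # Log density of a standard logistic: -v - 2 * softplus(-v)
    return -v - 2 * jax.nn.softplus(-v)

x, M = 0.0, 64  # toy discrete value and number of importance samples (assumed)
u = random.uniform(random.PRNGKey(1), (M,))
# With uniform q(u|x), log q = 0, so the log weights are just log p(x + u_m)
log_w = logistic_logpdf(x + u)

single_sample = log_w.mean()                  # average of M single-sample bounds
importance = logsumexp(log_w) - jnp.log(M)    # log of the average weight, as in ImageFlow
exact = jnp.log(jax.nn.sigmoid(x + 1.0) - jax.nn.sigmoid(x))
print(single_sample, importance, exact)
```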

### Variational Dequantization

Dequantization uses a uniform distribution for the noise \(u\), which effectively leads to images being represented as hypercubes (cubes in high dimensions) with sharp borders. However, modeling such sharp borders is not easy for a flow, as it uses smooth transformations to convert it into a Gaussian distribution.

Another way of seeing this is to change the prior distribution in the previous visualization. Imagine we have independent Gaussian noise on pixels, which is commonly the case for real-world photographs. In that case, the flow would have to model a distribution as above, but with the individual volumes scaled as follows:


```
visualize_dequantization(quants=8, prior=np.array([0.075, 0.2, 0.4, 0.2, 0.075, 0.025, 0.0125, 0.0125]))
```

Transforming such a distribution into a Gaussian is a difficult task, especially with such hard borders. Dequantization has therefore been extended to more sophisticated, learnable distributions beyond uniform in a variational framework. In particular, recalling the learning objective \(\log p(x) = \log \mathbb{E}_{u}\left[\frac{p(x+u)}{q(u|x)} \right]\), the uniform distribution can be replaced by a learned distribution \(q_{\theta}(u|x)\) with support over \(u\in[0,1)^D\). This approach is called Variational Dequantization and has been proposed by Ho et al. [3]. How can we learn such a distribution? We can use a second normalizing flow that takes \(x\) as external input and learns a flexible distribution over \(u\). To ensure a support over \([0,1)^D\), we can apply a sigmoid activation function as the final flow transformation.

Inheriting from the original dequantization class, we can implement variational dequantization as follows:


```
class VariationalDequantization(Dequantization):
    var_flows : Sequence[nn.Module] = None # A list of flow transformations to use for modeling q(u|x)

    def dequant(self, z, ldj, rng):
        z = z.astype(jnp.float32)
        # We condition the flows on x, i.e. the original image, rescaled to [-1, 1]
        img = (z / (self.quants - 1)) * 2 - 1
        # Prior of u is a uniform distribution as before
        # As most flow transformations are defined on [-infinity,+infinity], we apply an inverse sigmoid first.
        rng, uniform_rng = random.split(rng)
        deq_noise = random.uniform(uniform_rng, z.shape)
        deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=True)
        if self.var_flows is not None:
            for flow in self.var_flows:
                deq_noise, ldj, rng = flow(deq_noise, ldj, rng, reverse=False, orig_img=img)
        deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=False)
        # After the flows, apply u as in standard dequantization (same as 256 for 8-bit images)
        z = (z + deq_noise) / self.quants
        ldj -= np.log(self.quants) * np.prod(z.shape[1:])
        return z, ldj, rng
```

Variational dequantization can be used as a substitute for dequantization. We will compare dequantization and variational dequantization in later experiments.

### Coupling layers

Next, we look at possible transformations to apply inside the flow. A recent popular flow layer, which works well in combination with deep neural networks, is the coupling layer introduced by Dinh et al. [1]. The input \(z\) is arbitrarily split into two parts, \(z_{1:j}\) and \(z_{j+1:d}\), of which the first remains unchanged by the flow. Yet, \(z_{1:j}\) is used to parameterize the transformation for the second part, \(z_{j+1:d}\). Various transformations have been proposed in recent times [3,4], but here we will settle for the simplest and most efficient one: affine coupling. In this coupling layer, we apply an affine transformation by shifting the input by a bias \(\mu\) and scaling it by \(\sigma\). In other words, our transformation looks as follows:

\[
z'_{j+1:d} = \mu_{\theta}(z_{1:j}) + \sigma_{\theta}(z_{1:j}) \odot z_{j+1:d}
\]

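The functions \(\mu\) and \(\sigma\) are implemented as a shared neural network, and the sum and multiplication are performed element-wise. The LDJ is thereby the sum of the logs of the scaling factors: \(\sum_i \left[\log \sigma_{\theta}(z_{1:j})\right]_i\). Inverting the layer is as simple as subtracting the bias and dividing by the scale:

\[
z_{j+1:d} = \left(z'_{j+1:d} - \mu_{\theta}(z_{1:j})\right) / \sigma_{\theta}(z_{1:j})
\]

This is possible because the first part is left unchanged (\(z'_{1:j} = z_{1:j}\)), so \(\mu_{\theta}\) and \(\sigma_{\theta}\) can be recomputed from the output of the layer.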
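A minimal standalone sketch of one affine coupling step on a 4-dimensional vector, following the equations above (scale, then shift). The mask, the weight matrices `W_s` and `W_t`, and the `tanh`-bounded log-scale are hypothetical stand-ins for the shared neural network. The sketch checks that the inverse recovers the input and that the LDJ equals the log-determinant of the full Jacobian computed by autodiff. Note that the notebook's `CouplingLayer` below shifts before scaling, which is an equally valid design choice.

```
import jax
import jax.numpy as jnp

# 1 = part z_{1:j} that stays unchanged, 0 = part z_{j+1:d} that is transformed
mask = jnp.array([1.0, 1.0, 0.0, 0.0])
# Hypothetical fixed weights standing in for the shared network
W_s = 0.1 * jnp.arange(16.0).reshape(4, 4) - 0.8
W_t = 0.05 * jnp.arange(16.0).reshape(4, 4)

def mu_log_sigma(z):
    # The "network" only sees the unchanged part of z
    z_in = z * mask
    log_sigma = jnp.tanh(W_s @ z_in) * (1 - mask)  # zero on the unchanged part
    mu = (W_t @ z_in) * (1 - mask)
    return mu, log_sigma

def coupling_forward(z):
    mu, log_sigma = mu_log_sigma(z)
    z_out = mu + jnp.exp(log_sigma) * z  # identity where mask == 1
    ldj = log_sigma.sum()                # sum of the log scaling factors
    return z_out, ldj

def coupling_inverse(z_out):
    # z_out equals z on the unchanged part, so mu and sigma can be recomputed from z_out
    mu, log_sigma = mu_log_sigma(z_out)
    return (z_out - mu) * jnp.exp(-log_sigma)

z = jnp.array([0.5, -1.0, 2.0, 0.3])
z_out, ldj = coupling_forward(z)
# Reference LDJ from the full (block-triangular) Jacobian
ldj_ref = jnp.linalg.slogdet(jax.jacfwd(lambda v: coupling_forward(v)[0])(z))[1]
print(ldj, ldj_ref)
```

The Jacobian is block-triangular, so its determinant is just the product of the scaling factors of the transformed part. This is what makes coupling layers cheap to evaluate in both directions.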

We can also visualize the coupling layer in form of a computation graph, where \(z_1\) represents \(z_{1:j}\), and \(z_2\) represents \(z_{j+1:d}\):

In our implementation, we will realize the splitting of variables as masking. The variables to be transformed, \(z_{j+1:d}\), are masked when passing \(z\) to the shared network to predict the transformation parameters. When applying the transformation, we mask the parameters for \(z_{1:j}\) so that we have an identity operation for those variables:

```
[11]:
```

```
class CouplingLayer(nn.Module):
    network : nn.Module  # NN to use in the flow for predicting mu and sigma
    mask : np.ndarray  # Binary mask where 0 denotes that the element should be transformed, and 1 not.
    c_in : int  # Number of input channels

    def setup(self):
        self.scaling_factor = self.param('scaling_factor',
                                         nn.initializers.zeros,
                                         (self.c_in,))

    def __call__(self, z, ldj, rng, reverse=False, orig_img=None):
        """
        Inputs:
            z - Latent input to the flow
            ldj - The current ldj of the previous flows.
                  The ldj of this layer will be added to this tensor.
            rng - PRNG state
            reverse - If True, we apply the inverse of the layer.
            orig_img (optional) - Only needed in VarDeq. Allows external
                                  input to condition the flow on (e.g. original image)
        """
        # Apply network to masked input
        z_in = z * self.mask
        if orig_img is None:
            nn_out = self.network(z_in)
        else:
            nn_out = self.network(jnp.concatenate([z_in, orig_img], axis=-1))
        s, t = jnp.split(nn_out, 2, axis=-1)

        # Stabilize scaling output
        s_fac = jnp.exp(self.scaling_factor).reshape(1, 1, 1, -1)
        s = nn.tanh(s / s_fac) * s_fac

        # Mask outputs (only transform the second part)
        s = s * (1 - self.mask)
        t = t * (1 - self.mask)

        # Affine transformation
        if not reverse:
            # Whether we first shift and then scale, or the other way round,
            # is a design choice, and usually does not have a big impact
            z = (z + t) * jnp.exp(s)
            ldj += s.sum(axis=(1, 2, 3))
        else:
            z = (z * jnp.exp(-s)) - t
            ldj -= s.sum(axis=(1, 2, 3))

        return z, ldj, rng
```
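To see why the LDJ of this layer is simply the sum of \(s\), the following plain-NumPy sketch (our own illustration, not part of the tutorial's code) applies the same masked affine coupling to a flat 6-dimensional vector and compares the layer's LDJ with a finite-difference estimate of \(\log|\det J|\). The function `toy_network` is a hypothetical stand-in for `GatedConvNet`; any function of the kept variables would do, because the Jacobian is triangular up to a permutation of the variables, with \(\exp(s)\) on the diagonal for the transformed entries and 1 for the kept ones.

```python
import numpy as np


def toy_network(z_in, w):
    """Hypothetical stand-in for GatedConvNet: maps the masked input to (s, t)."""
    h = np.tanh(z_in @ w)  # any function of the kept (unmasked) variables works here
    s, t = np.split(h, 2, axis=-1)
    return s, t


def coupling(z, mask, w, reverse=False):
    """Masked affine coupling on a flat vector (mask: 1 = keep, 0 = transform)."""
    s, t = toy_network(z * mask, w)
    s, t = s * (1 - mask), t * (1 - mask)  # identity for the kept variables
    if not reverse:
        return (z + t) * np.exp(s), s.sum()
    return z * np.exp(-s) - t, -s.sum()


def numerical_logdet(f, z, eps=1e-5):
    """log|det J| of f at z, with J estimated by central finite differences."""
    d = z.shape[0]
    jac = np.zeros((d, d))
    for i in range(d):
        dz = np.zeros(d)
        dz[i] = eps
        jac[:, i] = (f(z + dz) - f(z - dz)) / (2 * eps)
    return np.linalg.slogdet(jac)[1]


rng = np.random.default_rng(0)
d = 6
z = rng.normal(size=d)
mask = np.array([1.0, 0.0, 1.0, 0.0, 1.0, 0.0])  # alternating mask, like a 1D checkerboard
w = rng.normal(size=(d, 2 * d))

z_out, ldj = coupling(z, mask, w)
z_rec, _ = coupling(z_out, mask, w, reverse=True)
print(np.allclose(z_rec, z))  # the inverse recovers the input
print(np.isclose(ldj, numerical_logdet(lambda v: coupling(v, mask, w)[0], z)))
```

Both checks print `True`: the inverse works because the kept variables, and hence \(s\) and \(t\), are identical in the forward and reverse pass.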

For stabilization purposes, we apply a \(\tanh\) activation function on the scaling output. This prevents sudden large output values for the scaling that can destabilize training. To still allow scaling outputs smaller than -1 or larger than 1, we have a learnable parameter per channel, called `scaling_factor`. Passed through an exponential, it stretches the tanh to different limits. Below, we visualize the effect of the scaling factor on the output activation of the scaling terms:

```
[12]:
```

```
x = jnp.arange(-5, 5, 0.01)
scaling_factors = [0.5, 1, 2]
sns.set()
fig, ax = plt.subplots(1, 3, figsize=(12, 3))
for i, scale in enumerate(scaling_factors):
    y = nn.tanh(x / scale) * scale
    ax[i].plot(x, y)
    ax[i].set_title("Scaling factor: " + str(scale))
    ax[i].set_ylim(-3, 3)
plt.subplots_adjust(wspace=0.4)
sns.reset_orig()
plt.show()
```
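What the bound means in practice can be seen from the multiplicative scale \(\exp(s)\) it induces. The NumPy sketch below repeats the stabilization formula from `CouplingLayer` (the helper name `soft_clamp` is ours). With the zero initialization of `scaling_factor`, the limit is \(f = \exp(0) = 1\), so each layer multiplies an element by a factor between roughly 0.37 and 2.72; a learned \(f = 2\) widens that range, while small raw outputs pass through almost unchanged.

```python
import numpy as np


def soft_clamp(s, scaling_factor):
    """Same stabilization as in CouplingLayer: tanh(s / f) * f with f = exp(scaling_factor)."""
    f = np.exp(scaling_factor)
    return np.tanh(s / f) * f


raw_s = np.array([-100.0, -1.0, 0.0, 0.01, 1.0, 100.0])  # raw network outputs
for param in [0.0, np.log(2.0)]:  # 0.0 corresponds to the zeros initialization in setup()
    f = np.exp(param)
    s = soft_clamp(raw_s, param)
    print(f"f = {f:.1f}: s = {np.round(s, 3)}, exp(s) within [{np.exp(-f):.3f}, {np.exp(f):.3f}]")
```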

Coupling layers generalize to any masking technique we could think of. However, the most common approach for images is to split the input \(z\) in half, using a checkerboard mask or a channel mask. A checkerboard mask splits the variables across the height and width dimensions and assigns every other pixel to \(z_{j+1:d}\). Thereby, the mask is shared across channels. In contrast, the channel mask assigns half of the channels to \(z_{j+1:d}\), and the other half to \(z_{1:j}\). Note that when we apply multiple coupling layers, we invert the masking for every other layer so that each variable is transformed a similar number of times.

Let’s implement a function that creates a checkerboard mask and a channel mask for us:

```
[13]:
```

```
def create_checkerboard_mask(h, w, invert=False):
    x, y = jnp.arange(h, dtype=jnp.int32), jnp.arange(w, dtype=jnp.int32)
    xx, yy = jnp.meshgrid(x, y, indexing='ij')
    mask = jnp.fmod(xx + yy, 2)
    mask = mask.astype(jnp.float32).reshape(1, h, w, 1)
    if invert:
        mask = 1 - mask
    return mask


def create_channel_mask(c_in, invert=False):
    mask = jnp.concatenate([
        jnp.ones((c_in//2,), dtype=jnp.float32),
        jnp.zeros((c_in-c_in//2,), dtype=jnp.float32)
    ])
    mask = mask.reshape(1, 1, 1, c_in)
    if invert:
        mask = 1 - mask
    return mask
```
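Why alternating `invert` matters can be checked by counting how often each pixel is actually transformed. The sketch below re-implements the checkerboard mask in NumPy (the name `checkerboard` is ours) and stacks eight coupling layers with alternating masks, as the flows later in this notebook do. Since a pixel is transformed exactly where its mask is 0, every pixel ends up being transformed in 4 of the 8 layers.

```python
import numpy as np


def checkerboard(h, w, invert=False):
    """NumPy re-implementation of create_checkerboard_mask (1 = keep, 0 = transform)."""
    xx, yy = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    mask = np.fmod(xx + yy, 2).astype(np.float32).reshape(1, h, w, 1)
    return 1 - mask if invert else mask


print(checkerboard(4, 4)[0, :, :, 0])

# Count, per pixel, how many of 8 alternating coupling layers transform it (mask == 0)
num_layers = 8
times_transformed = sum(1 - checkerboard(4, 4, invert=(i % 2 == 1)) for i in range(num_layers))
print(times_transformed[0, :, :, 0])  # every pixel is transformed in exactly 4 of the 8 layers
```

Without the inversion, half of the pixels would never be transformed at all and would simply be passed through to the prior.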

We can also visualize the corresponding masks for an image of size \(8\times 8\times 2\) (2 channels):

```
[14]:
```

```
checkerboard_mask = create_checkerboard_mask(h=8, w=8).repeat(2, -1)
channel_mask = jnp.resize(create_channel_mask(c_in=2), (1,8,8,2))
show_imgs(checkerboard_mask.swapaxes(0, 3), "Checkerboard mask")
show_imgs(channel_mask.swapaxes(0, 3), "Channel mask")
```

As a last aspect of coupling layers, we need to decide on the deep neural network we want to apply in the coupling layers. The input to the layers is an image, and hence we stick with a CNN. Because the input to a transformation depends on all transformations before, it is crucial to ensure a good gradient flow through the CNN back to the input, which is best achieved by a ResNet-like architecture. Specifically, we use a Gated ResNet that multiplies the output of each residual block by a sigmoid (\(\sigma\)) gate before adding it to the skip connection, similarly to the input gate in LSTMs. The details are not necessarily important here, and the network is strongly inspired by Flow++ [3] in case you are interested in building even stronger models.

```
[15]:
```

```
class ConcatELU(nn.Module):
    """
    Activation function that applies ELU in both directions (inverted and plain).
    Allows non-linearity while providing strong gradients for any input (important for final convolution)
    """

    def __call__(self, x):
        return jnp.concatenate([nn.elu(x), nn.elu(-x)], axis=-1)


class GatedConv(nn.Module):
    """ This module applies a two-layer convolutional ResNet block with input gate """
    c_in : int  # Number of input channels
    c_hidden : int  # Number of hidden dimensions

    @nn.compact
    def __call__(self, x):
        out = nn.Sequential([
            ConcatELU(),
            nn.Conv(self.c_hidden, kernel_size=(3, 3)),
            ConcatELU(),
            nn.Conv(2*self.c_in, kernel_size=(1, 1))
        ])(x)
        val, gate = jnp.split(out, 2, axis=-1)
        return x + val * nn.sigmoid(gate)


class GatedConvNet(nn.Module):
    c_hidden : int  # Number of hidden dimensions to use within the network
    c_out : int  # Number of output channels
    num_layers : int = 3  # Number of gated ResNet blocks to apply

    def setup(self):
        layers = []
        layers += [nn.Conv(self.c_hidden, kernel_size=(3, 3))]
        for layer_index in range(self.num_layers):
            layers += [GatedConv(self.c_hidden, self.c_hidden),
                       nn.LayerNorm()]
        layers += [ConcatELU(),
                   nn.Conv(self.c_out, kernel_size=(3, 3),
                           kernel_init=nn.initializers.zeros)]
        self.nn = nn.Sequential(layers)

    def __call__(self, x):
        return self.nn(x)
```
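Note the `kernel_init=nn.initializers.zeros` on the last convolution. Together with the zero bias that `flax.linen.Conv` uses by default, the network outputs exactly zero at initialization, so \(s = t = 0\) and every coupling layer starts out as the identity with an LDJ of zero (Glow [2] uses the same trick for its coupling layers). The NumPy sketch below checks this with the forward pass of `CouplingLayer` rewritten for a precomputed network output; `affine_coupling` is our own helper name, not part of the tutorial.

```python
import numpy as np


def affine_coupling(z, mask, nn_out, scaling_factor):
    """Forward direction of CouplingLayer, given the network output nn_out (channels-last)."""
    s, t = np.split(nn_out, 2, axis=-1)
    s_fac = np.exp(scaling_factor).reshape(1, 1, 1, -1)
    s = np.tanh(s / s_fac) * s_fac
    s, t = s * (1 - mask), t * (1 - mask)
    return (z + t) * np.exp(s), s.sum(axis=(1, 2, 3))


rng = np.random.default_rng(0)
z = rng.normal(size=(2, 4, 4, 1))  # batch of two 4x4 single-channel inputs
mask = (np.indices((4, 4)).sum(axis=0) % 2).reshape(1, 4, 4, 1).astype(np.float32)
nn_out = np.zeros((2, 4, 4, 2))  # output of a zero-initialized final convolution
z_out, ldj = affine_coupling(z, mask, nn_out, scaling_factor=np.zeros(1))
print(np.array_equal(z_out, z), ldj)
```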

```
[16]:
```

```
## Test GatedConvNet implementation
# Example features as input
main_rng, x_rng = random.split(main_rng)
x = random.normal(x_rng, (3, 32, 32, 16))
# Create network
net = GatedConvNet(c_hidden=32, c_out=18, num_layers=3)
# Initialize parameters of the network with random key and inputs
main_rng, init_rng = random.split(main_rng)
params = net.init(init_rng, x)['params']
# Apply network with parameters on the inputs
out = net.apply({'params': params}, x)
print('Out', out.shape)

del net, params
```

```
Out (3, 32, 32, 18)
```

### Training loop¶

Finally, we can put Dequantization, Variational Dequantization and Coupling Layers together to build our full normalizing flow on MNIST images. We apply 8 coupling layers in the main flow, and 4 for variational dequantization if applied. We use a checkerboard mask throughout the network since, with a single channel (black-and-white images), we cannot apply a channel mask. The overall architecture is visualized below.

```
[17]:
```

```
def create_simple_flow(use_vardeq=True):
    flow_layers = []
    if use_vardeq:
        vardeq_layers = [CouplingLayer(network=GatedConvNet(c_out=2, c_hidden=16),
                                       mask=create_checkerboard_mask(h=28, w=28, invert=(i%2==1)),
                                       c_in=1) for i in range(4)]
        flow_layers += [VariationalDequantization(var_flows=vardeq_layers)]
    else:
        flow_layers += [Dequantization()]

    for i in range(8):
        flow_layers += [CouplingLayer(network=GatedConvNet(c_out=2, c_hidden=32),
                                      mask=create_checkerboard_mask(h=28, w=28, invert=(i%2==1)),
                                      c_in=1)]

    flow_model = ImageFlow(flow_layers)
    return flow_model
```

For implementing the training loop, we use a trainer module similar to the ones in previous tutorials. Note that we again provide pre-trained models (see later in the notebook), as normalizing flows are particularly expensive to train. The pre-trained results also include the validation and test scores, since evaluation takes some time as well due to the added importance sampling.

```
[18]:
```

```
class TrainerModule:

    def __init__(self, model_name, flow, lr=1e-3, seed=42):
        super().__init__()
        self.model_name = model_name
        self.lr = lr
        self.seed = seed
        # Create empty model. Note: no parameters yet
        self.model = flow
        # Prepare logging
        self.exmp_imgs = next(iter(train_exmp_loader))[0]
        self.log_dir = os.path.join(CHECKPOINT_PATH, self.model_name)
        self.logger = SummaryWriter(log_dir=self.log_dir)
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model()

    def create_functions(self):
        # Training function
        def train_step(state, rng, batch):
            imgs, _ = batch
            loss_fn = lambda params: self.model.apply({'params': params}, imgs, rng, testing=False)
            (loss, rng), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)  # Get loss and gradients for loss
            state = state.apply_gradients(grads=grads)  # Optimizer update step
            return state, rng, loss
        self.train_step = jax.jit(train_step)
        # Eval function, which is separately jitted for validation and testing
        def eval_step(state, rng, batch, testing):
            return self.model.apply({'params': state.params}, batch[0], rng, testing=testing)
        self.eval_step = jax.jit(eval_step, static_argnums=(3,))

    def init_model(self):
        # Initialize model
        self.rng = jax.random.PRNGKey(self.seed)
        self.rng, init_rng, flow_rng = jax.random.split(self.rng, 3)
        params = self.model.init(init_rng, self.exmp_imgs, flow_rng)['params']
        # Initialize learning rate schedule and optimizer
        lr_schedule = optax.exponential_decay(
            init_value=self.lr,
            transition_steps=len(train_data_loader),
            decay_rate=0.99,
            end_value=0.01*self.lr
        )
        optimizer = optax.chain(
            optax.clip_by_global_norm(1.0),  # Clip gradients at 1
            optax.adam(lr_schedule)
        )
        # Initialize training state
        self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=optimizer)

    def train_model(self, train_loader, val_loader, num_epochs=500):
        # Train model for defined number of epochs
        best_eval = 1e6
        for epoch_idx in tqdm(range(1, num_epochs+1)):
            self.train_epoch(train_loader, epoch=epoch_idx)
            if epoch_idx % 5 == 0:
                eval_bpd = self.eval_model(val_loader, testing=False)
                self.logger.add_scalar('val/bpd', eval_bpd, global_step=epoch_idx)
                if eval_bpd < best_eval:
                    best_eval = eval_bpd
                    self.save_model(step=epoch_idx)
                self.logger.flush()

    def train_epoch(self, data_loader, epoch):
        # Train model for one epoch, and log avg loss
        avg_loss = 0.
        for batch in tqdm(data_loader, leave=False):
            self.state, self.rng, loss = self.train_step(self.state, self.rng, batch)
            avg_loss += loss
        avg_loss /= len(data_loader)
        self.logger.add_scalar('train/bpd', avg_loss.item(), global_step=epoch)

    def eval_model(self, data_loader, testing=False):
        # Test model on all images of a data loader and return avg loss
        losses = []
        batch_sizes = []
        for batch in data_loader:
            loss, self.rng = self.eval_step(self.state, self.rng, batch, testing=testing)
            losses.append(loss)
            batch_sizes.append(batch[0].shape[0])
        losses_np = np.stack(jax.device_get(losses))
        batch_sizes_np = np.stack(batch_sizes)
        avg_loss = (losses_np * batch_sizes_np).sum() / batch_sizes_np.sum()
        return avg_loss

    def save_model(self, step=0):
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir, target=self.state.params, step=step)

    def load_model(self, pretrained=False):
        # Load model. We use different checkpoint for pretrained models
        if not pretrained:
            params = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=self.state.params)
        else:
            params = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'),
                                                    target=self.state.params)
        self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.state.tx)

    def checkpoint_exists(self):
        # Check whether a pretrained model exists for this flow
        return os.path.isfile(os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'))
```

```
[19]:
```

```
def train_flow(flow, model_name="MNISTFlow"):
    # Create a trainer module with specified hyperparameters
    trainer = TrainerModule(model_name, flow)
    if not trainer.checkpoint_exists():  # Skip training if pretrained model exists
        trainer.train_model(train_data_loader,
                            val_loader,
                            num_epochs=200)
        trainer.load_model()
        val_bpd = trainer.eval_model(val_loader, testing=True)
        start_time = time.time()
        test_bpd = trainer.eval_model(test_loader, testing=True)
        duration = time.time() - start_time
        results = {'val': val_bpd, 'test': test_bpd, 'time': duration / len(test_loader) / trainer.model.import_samples}
    else:
        trainer.load_model(pretrained=True)
        with open(os.path.join(CHECKPOINT_PATH, f'{trainer.model_name}_results.json'), 'r') as f:
            results = json.load(f)

    # Bind parameters to model for easier inference
    trainer.model_bd = trainer.model.bind({'params': trainer.state.params})
    return trainer, results
```

## Multi-scale architecture¶

One disadvantage of normalizing flows is that they operate on the exact same dimensions as the input. If the input is high-dimensional, so is the latent space, which requires larger computational cost to learn suitable transformations. However, particularly in the image domain, many pixels contain less information in the sense that we could remove them without losing the semantic information of the image.

Based on this intuition, deep normalizing flows on images commonly apply a multi-scale architecture [1]. After the first \(N\) flow transformations, we split off half of the latent dimensions and directly evaluate them on the prior. The other half is run through \(N\) more flow transformations, and depending on the size of the input, we split it in half again or stop at this point. The two operations involved in this setup are `Squeeze` and `Split`, which we will review more closely and implement below.

### Squeeze and Split¶

When we want to remove half of the pixels in an image, we have the problem of deciding which variables to cut, and how to rearrange the image. Thus, a squeeze operation, which divides the image into subsquares of shape \(2\times 2\times C\) and reshapes them into \(1\times 1\times 4C\) blocks, is commonly applied before the split. Effectively, we reduce the height and width of the image by a factor of 2 while scaling the number of channels by 4. Afterwards, we can perform the split operation over channels without needing to rearrange the pixels. The smaller scale also makes the overall architecture more efficient. Visually, the squeeze operation should transform the input as follows:

The input of \(4\times 4\times 1\) is scaled to \(2\times 2\times 4\) following the idea of grouping the pixels in \(2\times 2\times 1\) subsquares. Next, let’s try to implement this layer:

```
[20]:
```

```
class SqueezeFlow(nn.Module):

    def __call__(self, z, ldj, rng, reverse=False):
        B, H, W, C = z.shape
        if not reverse:
            # Forward direction: H x W x C => H/2 x W/2 x 4C
            z = z.reshape(B, H//2, 2, W//2, 2, C)
            z = z.transpose((0, 1, 3, 2, 4, 5))
            z = z.reshape(B, H//2, W//2, 4*C)
        else:
            # Reverse direction: H/2 x W/2 x 4C => H x W x C
            z = z.reshape(B, H, W, 2, 2, C//4)
            z = z.transpose((0, 1, 3, 2, 4, 5))
            z = z.reshape(B, H*2, W*2, C//4)
        return z, ldj, rng
```

Before moving on, we can verify our implementation by comparing our output with the example figure above:

```
[21]:
```

```
sq_flow = SqueezeFlow()
rand_img = jnp.arange(1,17).reshape(1, 4, 4, 1)
print("Image (before)\n", rand_img.transpose(0, 3, 1, 2)) # Permute for readability
forward_img, _, _ = sq_flow(rand_img, ldj=None, rng=None, reverse=False)
print("\nImage (forward)\n", forward_img)
reconst_img, _, _ = sq_flow(forward_img, ldj=None, rng=None, reverse=True)
print("\nImage (reverse)\n", reconst_img.transpose(0, 3, 1, 2))
```

```
Image (before)
[[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]]
Image (forward)
[[[[ 1 2 5 6]
[ 3 4 7 8]]
[[ 9 10 13 14]
[11 12 15 16]]]]
Image (reverse)
[[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]]
```
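You may notice that `SqueezeFlow` returns `ldj` unchanged. This is correct because squeezing only rearranges entries: it is a linear map whose Jacobian is a permutation matrix, with \(|\det J| = 1\). The NumPy sketch below (our own check, with a NumPy copy of the forward direction) builds that Jacobian explicitly for a \(4\times 4\times 1\) input by squeezing each basis vector.

```python
import numpy as np


def squeeze(z):
    """NumPy version of SqueezeFlow's forward direction: (B, H, W, C) -> (B, H/2, W/2, 4C)."""
    B, H, W, C = z.shape
    z = z.reshape(B, H // 2, 2, W // 2, 2, C).transpose(0, 1, 3, 2, 4, 5)
    return z.reshape(B, H // 2, W // 2, 4 * C)


shape = (1, 4, 4, 1)
d = int(np.prod(shape))
# Squeeze is linear, so the i-th column of its Jacobian is squeeze(e_i) for the i-th basis vector
basis = np.eye(d)
jac = np.stack([squeeze(basis[i].reshape(shape)).reshape(-1) for i in range(d)], axis=1)
sign, logdet = np.linalg.slogdet(jac)
print(jac.sum(axis=0), jac.sum(axis=1))  # a single 1 in every column and every row
print(sign, logdet)
```

The log-determinant is 0, so the squeeze adds nothing to the LDJ, and the first output position holds the values 1, 2, 5, 6 as in the printout above.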

The split operation divides the input into two parts and evaluates one part directly on the prior. To fit this operation into the interface of the previous layers, we add the log prior probability of the split-off part to the log determinant Jacobian of the layer. This has the same effect as combining all variable splits at the end of the flow and evaluating them together on the prior.

```
[22]:
```

```
class SplitFlow(nn.Module):

    def __call__(self, z, ldj, rng, reverse=False):
        if not reverse:
            # Keep the first half of the channels, evaluate the second half on the prior
            z, z_split = jnp.split(z, 2, axis=-1)
            ldj += jax.scipy.stats.norm.logpdf(z_split).sum(axis=(1, 2, 3))
        else:
            # Sample the split-off half from the prior
            z_split = random.normal(rng, z.shape)
            z = jnp.concatenate([z, z_split], axis=-1)
            ldj -= jax.scipy.stats.norm.logpdf(z_split).sum(axis=(1, 2, 3))
        return z, ldj, rng
```
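The claim that evaluating the split-off half early is equivalent to evaluating everything at the end rests on the prior being a factorized standard normal, so its log-density is a sum over all dimensions. The NumPy sketch below illustrates this on a random batch at the \(14\times 14\times 4\) scale; the helper `std_normal_logpdf` is ours and mirrors `jax.scipy.stats.norm.logpdf`. In the real flow the kept half passes through further coupling layers first, which we skip here because they only contribute their own LDJ terms.

```python
import numpy as np


def std_normal_logpdf(x):
    """Elementwise log-density of a standard normal, like jax.scipy.stats.norm.logpdf."""
    return -0.5 * (x ** 2 + np.log(2 * np.pi))


rng = np.random.default_rng(0)
z = rng.normal(size=(3, 14, 14, 4))  # a batch of latents at the 14x14x4 scale
z_keep, z_split = np.split(z, 2, axis=-1)  # same channel split as SplitFlow

early = std_normal_logpdf(z_split).sum(axis=(1, 2, 3))  # term SplitFlow adds to the ldj
late = std_normal_logpdf(z_keep).sum(axis=(1, 2, 3))  # prior term for the kept half at the end
joint = std_normal_logpdf(z).sum(axis=(1, 2, 3))  # evaluating all variables together
print(np.allclose(early + late, joint))
```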

### Building a multi-scale flow¶

After defining the squeeze and split operations, we are finally able to build our own multi-scale flow. Deep normalizing flows such as Glow and Flow++ [2,3] often apply a split operation directly after squeezing. However, with shallow flows, we need to be more thoughtful about where to place the split operation, as we need at least a minimum number of transformations on each variable. Our setup is inspired by the original RealNVP architecture [1], which is shallower than other, more recent state-of-the-art architectures.

Hence, for the MNIST dataset, we will apply the first squeeze operation after two coupling layers, but don't apply a split operation yet. Because we have only used two coupling layers and each variable has only been transformed once, a split operation would be too early. We apply two more coupling layers before finally applying a split flow and squeezing again. The last four coupling layers operate on a scale of \(7\times 7\times 8\). The full flow architecture is shown below.
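To double-check the shapes in this design, the plain-Python trace below follows an MNIST image through the squeeze and split steps just described (the helper names are ours, not part of the tutorial). Half of the 784 input dimensions are evaluated on the prior at the \(14\times 14\) scale, and the other half ends up at the final \(7\times 7\times 8\) scale.

```python
def squeeze_shape(h, w, c):
    """Shape after a squeeze: H x W x C -> H/2 x W/2 x 4C."""
    return (h // 2, w // 2, 4 * c)


def split_shape(h, w, c):
    """Shapes (kept, split off) after splitting the channels in half."""
    return (h, w, c // 2), (h, w, c - c // 2)


def numel(shape):
    n = 1
    for s in shape:
        n *= s
    return n


shape = (28, 28, 1)  # input, followed by 2 checkerboard coupling layers
shape = squeeze_shape(*shape)  # 14x14x4, followed by 2 channel-mask coupling layers
shape, split_off = split_shape(*shape)  # 14x14x2 kept, 14x14x2 evaluated on the prior
shape = squeeze_shape(*shape)  # 7x7x8, followed by the final 4 coupling layers
print(split_off, numel(split_off), shape, numel(shape))
```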

Note that while the feature maps inside the coupling layers shrink with the height and width of the input, the increased number of channels is not directly considered. To counteract this, we increase the hidden dimensions for the coupling layers on the squeezed input. The hidden dimensions are often scaled by 2, as this increases the computational cost of the convolutions by roughly a factor of 4, which cancels out the factor-4 reduction in spatial size from the squeeze operation. However, we will choose the hidden dimensionalities \(32, 48, 64\) for the three scales respectively to keep the number of parameters reasonable and show the efficiency of multi-scale architectures.

```
[23]:
```

```
def create_multiscale_flow():
    flow_layers = []

    vardeq_layers = [CouplingLayer(network=GatedConvNet(c_out=2, c_hidden=16),
                                   mask=create_checkerboard_mask(h=28, w=28, invert=(i%2==1)),
                                   c_in=1) for i in range(4)]
    flow_layers += [VariationalDequantization(var_flows=vardeq_layers)]

    flow_layers += [CouplingLayer(network=GatedConvNet(c_out=2, c_hidden=32),
                                  mask=create_checkerboard_mask(h=28, w=28, invert=(i%2==1)),
                                  c_in=1) for i in range(2)]
    flow_layers += [SqueezeFlow()]
    for i in range(2):
        flow_layers += [CouplingLayer(network=GatedConvNet(c_out=8, c_hidden=48),
                                      mask=create_channel_mask(c_in=4, invert=(i%2==1)),
                                      c_in=4)]
    flow_layers += [SplitFlow(),
                    SqueezeFlow()]
    for i in range(4):
        flow_layers += [CouplingLayer(network=GatedConvNet(c_out=16, c_hidden=64),
                                      mask=create_channel_mask(c_in=8, invert=(i%2==1)),
                                      c_in=8)]

    flow_model = ImageFlow(flow_layers)
    return flow_model
```
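A back-of-the-envelope check of the cost argument: a \(3\times 3\) convolution with \(C\) input and \(C\) output channels on an \(H\times W\) feature map needs about \(H\cdot W\cdot 9\cdot C^2\) multiply-accumulate operations (MACs). The helper below is our own rough estimate; it counts a single hidden-to-hidden convolution and ignores the channel doubling of `ConcatELU`, which scales every configuration by the same factor. It shows that doubling the hidden size after a squeeze keeps the cost constant, while the chosen sizes 48 and 64 make the lower scales cheaper per layer.

```python
def conv_macs(h, w, c_in, c_out, k=3):
    """Multiply-accumulate operations of a k x k convolution with same padding, per image."""
    return h * w * k * k * c_in * c_out


base = conv_macs(28, 28, 32, 32)
configs = [
    ("28x28, c_hidden=32", conv_macs(28, 28, 32, 32)),
    ("14x14, c_hidden=64 (doubled)", conv_macs(14, 14, 64, 64)),
    ("14x14, c_hidden=48 (chosen)", conv_macs(14, 14, 48, 48)),
    ("7x7, c_hidden=64 (chosen)", conv_macs(7, 7, 64, 64)),
]
for name, macs in configs:
    print(f"{name:30s} {macs:>10,d} MACs ({macs / base:.4f}x)")
```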

## Analysing the flows¶

In the last part of the notebook, we will train all the models we have implemented above, and try to analyze the effect of the multi-scale architecture and variational dequantization.

### Training flow variants¶

Before we can analyse the flow models, we need to train them first. We provide pre-trained models that contain the validation and test performance, and run-time information. As flow models are computationally expensive, we advise you to rely on those pretrained models for a first run through the notebook.

```
[24]:
```

```
flow_dict = {"simple": {}, "vardeq": {}, "multiscale": {}}
flow_dict["simple"]["model"], flow_dict["simple"]["result"] = train_flow(create_simple_flow(use_vardeq=False), model_name="MNISTFlow_simple")
flow_dict["vardeq"]["model"], flow_dict["vardeq"]["result"] = train_flow(create_simple_flow(use_vardeq=True), model_name="MNISTFlow_vardeq")
flow_dict["multiscale"]["model"], flow_dict["multiscale"]["result"] = train_flow(create_multiscale_flow(), model_name="MNISTFlow_multiscale")
```

We can show the difference in number of parameters below:

```
[25]:
```

```
def print_num_params(model):
    num_params = sum([np.prod(p.shape) for p in jax.tree_util.tree_leaves(model.state.params)])
    print("Number of parameters: {:,}".format(num_params))

print_num_params(flow_dict["simple"]["model"])
print_num_params(flow_dict["vardeq"]["model"])
print_num_params(flow_dict["multiscale"]["model"])
```

```
Number of parameters: 556,312
Number of parameters: 628,388
Number of parameters: 1,711,818
```

Although the multi-scale flow has almost 3 times the parameters of the single scale flow, it is not necessarily more computationally expensive than its counterpart. We will compare the runtime in the following experiments as well.

### Density modeling and sampling¶

Firstly, we can compare the models on their quantitative results. The following table shows all important statistics. The inference time specifies the time needed to determine the probability for a batch of 64 images for each model, and the sampling time the duration it took to sample a batch of 64 images.

```
[26]:
```

```
%%html
<!-- Some HTML code to increase font size in the following table -->
<style>
th {font-size: 120%;}
td {font-size: 120%;}
</style>
```

```
[27]:
```

```
import tabulate
from IPython.display import display, HTML
table = [[key,
          "%4.3f bpd" % flow_dict[key]["result"]["val"],
          "%4.3f bpd" % flow_dict[key]["result"]["test"],
          "%2.1f ms" % (1000 * flow_dict[key]["result"]["time"]),
          "%2.1f ms" % (1000 * flow_dict[key]["result"].get("samp_time", 0)),
          "{:,}".format(sum([np.prod(p.shape) for p in jax.tree_util.tree_leaves(flow_dict[key]["model"].state.params)]))]
         for key in flow_dict]
display(HTML(tabulate.tabulate(table, tablefmt='html', headers=["Model", "Validation Bpd", "Test Bpd", "Inference time", "Sampling time", "Num Parameters"])))
```

Model | Validation Bpd | Test Bpd | Inference time | Sampling time | Num Parameters |
---|---|---|---|---|---|
simple | 1.080 bpd | 1.078 bpd | 8.5 ms | 9.0 ms | 556,312 |
vardeq | 1.043 bpd | 1.041 bpd | 10.9 ms | 9.2 ms | 628,388 |
multiscale | 1.023 bpd | 1.021 bpd | 7.1 ms | 5.3 ms | 1,711,818 |

As we initially expected, using variational dequantization improves upon standard dequantization in terms of bits per dimension. Although a difference of 0.04 bpd doesn't seem impressive at first, it is a considerable step for generative models (most state-of-the-art models improve upon previous models by 0.02-0.1 bpd on CIFAR, where bpd scores are about three times as high). While it takes longer to evaluate the probability of an image due to the variational dequantization, which also leads to a longer training time, it does not affect the sampling time. This is because inverting variational dequantization is the same as inverting dequantization: finding the next lower integer.
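Since bits per dimension are the negative log-likelihood divided by the number of dimensions \(D\) and by \(\ln 2\), a small bpd gap accumulates over all 784 pixels of an MNIST image. The short calculation below (helper name ours) converts the test-set gap between the simple and the variational dequantization model from the table back into bits and nats per image.

```python
import math


def bpd_to_nats(bpd, num_dims):
    """Total negative log-likelihood in nats per image from bits per dimension."""
    return bpd * num_dims * math.log(2)


num_dims = 28 * 28 * 1  # MNIST images: 784 dimensions
gap_bpd = 1.078 - 1.041  # test bpd of "simple" minus "vardeq"
print(f"{gap_bpd:.3f} bpd = {gap_bpd * num_dims:.1f} bits = {bpd_to_nats(gap_bpd, num_dims):.1f} nats per image")
```

The gap corresponds to about 29 bits, or roughly 20 nats, of average log-likelihood per image.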

When we compare the two models to the multi-scale architecture, we can see that the bits per dimension score dropped again by about 0.02 bpd. Additionally, the inference time and sampling time improved notably despite having more parameters. Thus, we see that the multi-scale flow is not only stronger for density modeling, but also more efficient.

Next, we can test the sampling quality of the models. Note that the samples for variational dequantization and standard dequantization are very similar, and hence we only visualize the ones for variational dequantization and the multi-scale model here. However, feel free to also test out the `"simple"` model. The seeds are set to obtain reproducible generations and are not cherry-picked.

```
[28]:
```

```
sample_rng = random.PRNGKey(44)
samples, _ = flow_dict["vardeq"]["model"].model_bd.sample(img_shape=[16,28,28,1], rng=sample_rng)
show_imgs(samples)
```

```
[29]:
```

```
sample_rng = random.PRNGKey(44)
samples, _ = flow_dict["multiscale"]["model"].model_bd.sample(img_shape=[16,7,7,8], rng=sample_rng)
show_imgs(samples)
```

From the few samples, we can see a clear difference between the single-scale (variational dequantization) model and the multi-scale model. The single-scale model has only learned local, small correlations, while the multi-scale model was able to learn full, global relations that form digits. This showcases another benefit of the multi-scale model. In contrast to VAEs, the outputs are sharp, as normalizing flows can naturally model complex, multi-modal distributions, while VAEs model the decoder output with independent per-pixel noise. Nevertheless, the samples from this flow are far from perfect, as not all samples show true digits.

### Interpolation in latent space¶

Another popular test for the smoothness of the latent space of generative models is to interpolate between two training examples. As normalizing flows are strictly invertible, we can guarantee that any image is represented in the latent space. We again compare the variational dequantization model with the multi-scale model below.

```
[30]:
```

```
def interpolate(model, rng, img1, img2, num_steps=8):
    """
    Inputs:
        model - object of ImageFlow class that represents the (trained) flow model
        rng - PRNG key used for encoding and sampling
        img1, img2 - Image tensors of shape [28, 28, 1]. Images between which should be interpolated.
        num_steps - Number of interpolation steps. 8 interpolation steps mean 6 intermediate pictures besides img1 and img2
    """
    imgs = np.stack([img1, img2], axis=0)
    z, _, rng = model.encode(imgs, rng)
    alpha = jnp.linspace(0, 1, num=num_steps).reshape(-1, 1, 1, 1)
    interpolations = z[0:1] * alpha + z[1:2] * (1 - alpha)
    interp_imgs, _ = model.sample(interpolations.shape[:1] + imgs.shape[1:], rng=rng, z_init=interpolations)
    show_imgs(interp_imgs, row_size=8)

exmp_imgs, _ = next(iter(train_exmp_loader))
```

```
[31]:
```

```
sample_rng = random.PRNGKey(42)
for i in range(2):
    interpolate(flow_dict["vardeq"]["model"].model_bd, sample_rng, exmp_imgs[2*i], exmp_imgs[2*i+1])
```

```
[32]:
```

```
sample_rng = random.PRNGKey(42)
for i in range(2):
    interpolate(flow_dict["multiscale"]["model"].model_bd, sample_rng, exmp_imgs[2*i], exmp_imgs[2*i+1])
```

The interpolations of the multi-scale model result in more realistic digits (first row \(7\leftrightarrow 8\leftrightarrow 6\), second row \(9\leftrightarrow 6\)), while the variational dequantization model focuses on local patterns that globally do not form a digit. For the multi-scale model, we did not actually perform the “true” interpolation between the two images, as we did not consider the variables that were split off along the flow (they were sampled randomly for all samples). However, as we will see in the next experiment, the early variables do not affect the overall image much.

### Visualization of latents in different levels of multi-scale¶

In the following, we will focus more on the multi-scale flow. We want to analyse what information is stored in the variables split off at early layers, and what information is stored in the final variables. For this, we sample 8 images that all share the same final latent variables but differ in the other part of the latent variables. Below we visualize three examples of this:

```
[33]:
```

```
sample_rng = random.PRNGKey(46)
for _ in range(3):
    sample_rng, iter_rng = random.split(sample_rng)
    z_init = random.normal(sample_rng, [1,7,7,8])
    z_init = z_init.repeat(8, 0)
    samples, sample_rng = flow_dict["multiscale"]["model"].model_bd.sample(img_shape=z_init.shape, rng=sample_rng, z_init=z_init)
    show_imgs(samples)
```

We see that the early split variables indeed have a smaller effect on the image. Still, small differences can be spotted when we look carefully at the borders of the digits. For instance, the hole at the left of the 0 changes for different samples, although all of them represent the same coarse structure. This shows that the flow indeed learns to separate the higher-level information in the final variables, while the early split ones contain local noise patterns.

### Visualizing Dequantization¶

As a final part of this notebook, we will look at the effect of variational dequantization. We have motivated variational dequantization by the issue that sharp edges/borders are difficult to model, since a flow prefers smooth, prior-like distributions. To check what noise distribution \(q(u|x)\) the flows in the variational dequantization module have learned, we can plot a histogram of output values from the dequantization and variational dequantization module.

```
[34]:
```

```
def visualize_dequant_distribution(model, rng, imgs, title):
    """
    Inputs:
        model - The flow of which we want to visualize the dequantization distribution
        rng - PRNG key used for sampling the dequantization noise
        imgs - Example training images of which we want to visualize the dequantization distribution
        title - Title of the histogram plot (or None)
    """
    ldj = jnp.zeros(imgs.shape[0], dtype=jnp.float32)
    dequant_vals = []
    for _ in tqdm(range(8), leave=False):
        d, _, rng = model.flows[0](imgs, ldj, rng, reverse=False)
        dequant_vals.append(d)
    dequant_vals = jnp.concatenate(dequant_vals, axis=0)
    dequant_vals = jax.device_get(dequant_vals.reshape(-1))
    sns.set()
    plt.figure(figsize=(10,3))
    plt.hist(dequant_vals, bins=256, color=to_rgb("C0")+(0.5,), edgecolor="C0", density=True)
    if title is not None:
        plt.title(title)
    plt.show()
    plt.close()

sample_imgs, _ = next(iter(train_exmp_loader))
```

```
[35]:
```

```
sample_rng = random.PRNGKey(42)
visualize_dequant_distribution(flow_dict["simple"]["model"].model_bd, sample_rng, sample_imgs, title="Dequantization")
```

```
[36]:
```

```
sample_rng = random.PRNGKey(42)
visualize_dequant_distribution(flow_dict["vardeq"]["model"].model_bd, sample_rng, sample_imgs, title="Variational dequantization")
```

The dequantization distribution in the first plot shows that the MNIST images have a strong bias towards 0 (black), and their distribution has a sharp border as mentioned before. The variational dequantization module has indeed learned a much smoother distribution with a Gaussian-like curve, which can be modeled much better. For the other values, we would need to visualize the distribution \(q(u|x)\) on a deeper level, depending on \(x\). However, as all \(u\)’s interact and depend on each other, we would need to visualize a distribution in 784 dimensions, which is not that intuitive anymore.

## Conclusion¶

In conclusion, we have seen how to implement our own normalizing flow, and what difficulties arise if we want to apply it to images. Dequantization is a crucial step in mapping the discrete images into continuous space to prevent undesirable delta-peak solutions. While dequantization creates hypercubes with hard borders, variational dequantization allows us to fit a flow much better on the data. This allows us to obtain a lower bits per dimension score, while not affecting the sampling speed. The most common flow element, the coupling layer, is simple to implement, and yet effective. Furthermore, multi-scale architectures help to capture the global image context while allowing us to efficiently scale up the flow. Normalizing flows are an interesting alternative to VAEs as they allow an exact likelihood estimate in continuous space, and we have the guarantee that every possible input \(x\) has a corresponding latent vector \(z\). Beyond continuous inputs and images, flows can also be applied to exploit the data structure in latent space, e.g. on graphs for the task of molecule generation [6]. Recent advances in Neural ODEs allow a flow with an infinite number of layers, called Continuous Normalizing Flows, whose potential is yet to be fully explored. Overall, normalizing flows are an exciting research area that will likely remain active over the next couple of years.

## References¶

[1] Dinh, L., Sohl-Dickstein, J., and Bengio, S. (2017). “Density estimation using Real NVP,” In: 5th International Conference on Learning Representations, ICLR 2017.

[2] Kingma, D. P., and Dhariwal, P. (2018). “Glow: Generative Flow with Invertible 1x1 Convolutions,” In: Advances in Neural Information Processing Systems, vol. 31, pp. 10215–10224.

[3] Ho, J., Chen, X., Srinivas, A., Duan, Y., and Abbeel, P. (2019). “Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design,” In: Proceedings of the 36th International Conference on Machine Learning, vol. 97, pp. 2722–2730.

[4] Durkan, C., Bekasov, A., Murray, I., and Papamakarios, G. (2019). “Neural Spline Flows,” In: Advances in Neural Information Processing Systems, pp. 7509–7520.

[5] Hoogeboom, E., Cohen, T. S., and Tomczak, J. M. (2020). “Learning Discrete Distributions by Dequantization,” arXiv preprint arXiv:2001.11235v1.

[6] Lippe, P., and Gavves, E. (2021). “Categorical Normalizing Flows via Continuous Transformations,” In: International Conference on Learning Representations, ICLR 2021.