Part 1.2: Profiling and Scaling Single-GPU Transformer Models


Author: Phillip Lippe

In the previous part, we saw how to implement mixed precision training, gradient accumulation, and gradient checkpointing on a simple MLP model. In this part, we will apply these techniques to a transformer model and see how they help us train large models with limited resources. We will also profile the model to identify bottlenecks and optimize its performance. It is recommended to go through Part 1.1 before starting this part, as we will be using the same techniques and concepts. We also assume that you are familiar with the transformer model and its components. If you are not, you can refer to the transformer model paper by Vaswani et al. and our transformer tutorial.

This notebook is designed to run on an accelerator, such as a GPU or TPU. If you are running this notebook on Google Colab, you can enable the GPU runtime by clicking on Runtime in the top menu, then Change runtime type, and selecting GPU from the Hardware accelerator dropdown. If the runtime fails, feel free to disable the GPU and run the notebook on the CPU. In that case, we recommend adjusting the configuration of the model to fit the available resources.

Prerequisites

To reduce code duplication between notebooks, we import functions from the previous notebook. For this, we have converted the most important functions into python scripts and uploaded them to the same repository. If you run this notebook on Google Colab, we need to download the scripts before importing the functions. If you run the notebook locally, they will already be available.

[1]:
import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download.
python_files = ["single_gpu.py", "utils.py"]
# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

The file utils.py contains some simple functionalities, such as setting the XLA flags we have seen in the previous tutorial. Let’s do that first.

[2]:
from utils import install_package, set_XLA_flags_gpu

set_XLA_flags_gpu()
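
For reference, such a helper typically sets the XLA_FLAGS environment variable before JAX initializes its backend. The exact flags live in utils.py and may differ from this illustrative sketch:

import os

# Illustrative only: the exact flags set by set_XLA_flags_gpu may differ.
# XLA flags must be set before JAX creates its backend.
os.environ["XLA_FLAGS"] = " ".join(
    [
        "--xla_gpu_enable_latency_hiding_scheduler=true",  # overlap communication and compute
        "--xla_gpu_enable_triton_gemm=false",  # prefer cuBLAS kernels for matmuls
    ]
)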

We import our standard libraries below.

[3]:
import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

# Install ml_collections on colab
try:
    from ml_collections import ConfigDict
except ModuleNotFoundError:
    install_package("ml_collections")
    from ml_collections import ConfigDict

# Type aliases
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

Finally, we import the functions and modules from our previous tutorial. If you are not familiar with any of these, check out Part 1.1.

[4]:
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics

Building an Optimized Transformer Model

In the following section, we will combine mixed precision, gradient checkpointing and gradient accumulation to train a larger Transformer model on a single GPU.

Model Definition

For passing hyperparameters and configurations to our modules, we will make use of ml-collections’ ConfigDict class (docs). A config dict is a dict-like data structure that supports dot access to its keys, and provides a ‘frozen’ version which is useful for JAX.
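
As a small illustration (not part of the original notebook), dot access and the frozen variant look as follows:

# Minimal ConfigDict illustration.
from ml_collections import ConfigDict, FrozenConfigDict

cfg = ConfigDict(dict(hidden_size=1024, dropout_rate=0.1))
cfg.hidden_size = 512  # attribute-style access and assignment
frozen_cfg = FrozenConfigDict(cfg)  # immutable and hashable, convenient with JAX transformations
print(cfg.hidden_size, frozen_cfg.dropout_rate)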

We start with implementing the MLP block of the Transformer model. As in the previous part, we support mixed precision via the dtype attribute of the config.

[5]:
class MLPBlock(nn.Module):
    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        input_features = x.shape[-1]
        x = nn.LayerNorm(dtype=self.config.dtype, name="pre_norm")(x)
        x = nn.Dense(
            features=self.config.mlp_expansion * input_features,
            dtype=self.config.dtype,
            name="input_layer",
        )(x)
        x = nn.gelu(x)
        x = nn.Dense(
            features=input_features,
            dtype=self.config.dtype,
            name="output_layer",
        )(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not self.train)(x)
        return x
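
As a quick, illustrative sanity check (the config fields below are only the ones the block actually reads), we can verify that with dtype=jnp.bfloat16 the parameters are still created in float32 (Flax's default param_dtype), while the forward pass returns bfloat16 activations:

# Illustrative check of the mixed-precision behavior of the MLP block.
test_config = ConfigDict(dict(dtype=jnp.bfloat16, mlp_expansion=4, dropout_rate=0.1))
mlp = MLPBlock(config=test_config, train=False)
x = jnp.ones((2, 16, 128), dtype=jnp.bfloat16)
variables = mlp.init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(lambda p: p.dtype, variables["params"]))  # all float32
print(mlp.apply(variables, x).dtype)  # bfloat16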

Next, we turn to the attention block. To support mixed precision with numerical stability, we perform the softmax in float32, as discussed before. If we instead used float16 precision, the attention logits could occasionally go out of range, leading to numerical instability (see e.g. Karras et al., 2023). Thus, we cast the query and key tensors to float32 before the dot product, and cast the attention weights back to bfloat16 after the softmax operation. Alternatively, one could keep the query and key tensors in bfloat16 if we are just short of GPU memory. We implement the adjusted dot product attention below:

[6]:
def dot_product_attention(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    mask: jax.Array | None,
    softmax_dtype: jnp.dtype = jnp.float32,
):
    """Dot-product attention.

    Follows the setup of https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.dot_product_attention,
    but supports switch to float32 for numerical stability during softmax.

    Args:
        query: The query array, shape [..., num queries, num heads, hidden size].
        key: The key array, shape [..., num keys, num heads, hidden size].
        value: The value array, shape [..., num keys, num heads, hidden size].
        mask: The boolean mask array (0 for masked values, 1 for non-masked). If None, no masking is applied.
        softmax_dtype: The dtype to use for the softmax and dot-product operation.

    Returns:
        The attention output array, shape [..., num queries, num heads, hidden size].
    """
    num_features = query.shape[-1]
    dtype = query.dtype
    scale = num_features**-0.5
    query = query * scale
    # Switch dtype right before the dot-product for numerical stability.
    query = query.astype(softmax_dtype)
    key = key.astype(softmax_dtype)
    weights = jnp.einsum("...qhd,...khd->...hqk", query, key)
    if mask is not None:
        weights = jnp.where(mask, weights, jnp.finfo(softmax_dtype).min)
    weights = nn.softmax(weights, axis=-1)
    # After softmax, switch back to the original dtype
    weights = weights.astype(dtype)
    new_vals = jnp.einsum("...hqk,...khd->...qhd", weights, value)
    new_vals = new_vals.astype(dtype)
    return new_vals
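
A small usage sketch (illustrative): with bfloat16 queries, keys and values, the softmax internally runs in float32, but the returned values are cast back to bfloat16:

# Illustrative check of the mixed-precision attention.
rng = jax.random.PRNGKey(0)
q_rng, k_rng, v_rng = jax.random.split(rng, 3)
q = jax.random.normal(q_rng, (2, 128, 8, 64), dtype=jnp.bfloat16)  # [batch, seq, heads, head_dim]
k = jax.random.normal(k_rng, (2, 128, 8, 64), dtype=jnp.bfloat16)
v = jax.random.normal(v_rng, (2, 128, 8, 64), dtype=jnp.bfloat16)
mask = nn.make_causal_mask(jnp.ones((2, 128)), dtype=jnp.bool_)  # [batch, 1, seq, seq]
out = dot_product_attention(q, k, v, mask=mask, softmax_dtype=jnp.float32)
print(out.shape, out.dtype)  # (2, 128, 8, 64) bfloat16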

With that, we can implement the attention block below. We use nn.DenseGeneral to implement the linear projections. Depending on the hidden size, it may be beneficial to split the query, key and value projections into separate smaller projections, which also gives the XLA compiler more flexibility to schedule the computation. For simplicity, we use a single fused projection here, which is commonly more efficient for small model sizes.

[7]:
class AttentionBlock(nn.Module):
    config: ConfigDict
    mask: jax.Array | None
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        input_features = x.shape[-1]
        x = nn.LayerNorm(dtype=self.config.dtype, name="pre_norm")(x)
        qkv = nn.DenseGeneral(
            features=(self.config.num_heads, self.config.head_dim * 3),
            dtype=self.config.dtype,
            name="qkv",
        )(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)
        x = dot_product_attention(q, k, v, mask=self.mask, softmax_dtype=self.config.softmax_dtype)
        x = nn.DenseGeneral(
            features=input_features,
            axis=(-2, -1),
            dtype=self.config.dtype,
            name="output_layer",
        )(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not self.train)(x)
        return x

We can now combine the two blocks to implement a full Transformer block. In this block, we want to support gradient checkpointing around the two individual sub-blocks. For this, the config has a remat key containing a sequence of names that indicates which modules to remat. We implement the Transformer block below:

[8]:
class TransformerBlock(nn.Module):
    config: ConfigDict
    mask: jax.Array | None
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        # MLP block
        mlp = MLPBlock
        if "MLP" in self.config.remat:
            mlp = nn.remat(mlp, prevent_cse=False)
        x = x + mlp(config=self.config, train=self.train, name="mlp")(x)
        # Attention block
        attn = AttentionBlock
        if "Attn" in self.config.remat:
            attn = nn.remat(attn, prevent_cse=False)
        x = x + attn(config=self.config, mask=self.mask, train=self.train, name="attn")(x)
        return x

With that, we are ready to implement the full Transformer model. We use the scan transformation to scan over the layers of the model to reduce the compilation time. We implement a text-based GPT-style autoregressive model, which uses an embedding layer to embed the input tokens, and a stack of Transformer blocks to process the tokens. We also add a final dense layer to map the output tokens to the vocabulary size. We implement the model below:

[9]:
class Transformer(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(
        self, x: jax.Array, mask: jax.Array | None = None, train: bool = True
    ) -> jax.Array:
        if mask is None and self.config.causal_mask:
            mask = nn.make_causal_mask(x, dtype=jnp.bool_)
        # Input layer.
        x = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="embed",
        )(x)
        pos_emb = self.param(
            "pos_emb",
            nn.initializers.normal(stddev=0.02),
            (self.config.max_seq_len, self.config.hidden_size),
        )
        pos_emb = pos_emb.astype(self.config.dtype)
        x = x + pos_emb[None, : x.shape[1]]
        # Transformer blocks.
        block_fn = functools.partial(TransformerBlock, config=self.config, mask=mask, train=train)
        if "Block" in self.config.remat:
            block_fn = nn.remat(block_fn, prevent_cse=False)
        if self.config.scan_layers:
            block = block_fn(name="block")
            x, _ = nn.scan(
                lambda module, carry, _: (module(carry), None),
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                length=self.config.num_layers,
            )(block, x, ())
        else:
            for l_idx in range(self.config.num_layers):
                x = block_fn(name=f"block_{l_idx}")(x)
        # Output layer.
        x = nn.LayerNorm(dtype=self.config.dtype, name="post_norm")(x)
        x = nn.Dense(
            features=self.config.num_outputs,
            dtype=self.config.dtype,
            name="output_layer",
        )(x)
        x = x.astype(jnp.float32)
        return x

Initialization

With the model set up, we can continue with the initialization. The initialization process is as usual, except that we create a more detailed config dict below to specify all hyperparameters of the model. By default, we run with bfloat16 precision and remat the MLP and attention blocks. The model has 12 layers with a hidden size of 1024. We also create a config for the data, which we will use to create the example batch. We create batches of roughly 32k tokens (batch size 64 at sequence length 512), which is large for a single GPU, but language models often train with ~1M tokens per batch. Feel free to change the hyperparameters to see how the model behaves with different settings.

[10]:
data_config = ConfigDict(
    dict(
        batch_size=64,
        seq_len=512,
        vocab_size=2048,
    )
)
model_config = ConfigDict(
    dict(
        hidden_size=1024,
        dropout_rate=0.1,
        mlp_expansion=4,
        num_layers=12,
        head_dim=128,
        causal_mask=True,
        max_seq_len=data_config.seq_len,
        vocab_size=data_config.vocab_size,
        num_outputs=data_config.vocab_size,
        dtype=jnp.bfloat16,
        softmax_dtype=jnp.float32,
        scan_layers=True,
        remat=("MLP", "Attn"),
    )
)
model_config.num_heads = model_config.hidden_size // model_config.head_dim
optimizer_config = ConfigDict(
    dict(
        learning_rate=4e-4,
        num_minibatches=4,
    )
)
config = ConfigDict(
    dict(
        model=model_config,
        optimizer=optimizer_config,
        data=data_config,
        seed=42,
    )
)
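
As a quick check of the batch scale implied by this config:

tokens_per_step = config.data.batch_size * config.data.seq_len
print(f"Tokens per optimizer step: {tokens_per_step:_}")  # 32_768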

We now create the model and the optimizer. We use Adam with a warmup exponential decay schedule, although the optimizer is not really relevant for the simple example at hand.

[11]:
model = Transformer(config=config.model)
optimizer = optax.adam(
    learning_rate=optax.warmup_exponential_decay_schedule(
        init_value=0,
        peak_value=config.optimizer.learning_rate,
        warmup_steps=10,
        transition_steps=1,
        decay_rate=0.99,
    )
)
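
To see what this schedule does, we can evaluate it at a few step counts (illustrative): the learning rate warms up linearly to the peak value over the first 10 steps and afterwards decays by 1% per step.

# Evaluate the warmup-exponential-decay schedule at a few steps (illustrative).
schedule = optax.warmup_exponential_decay_schedule(
    init_value=0,
    peak_value=config.optimizer.learning_rate,
    warmup_steps=10,
    transition_steps=1,
    decay_rate=0.99,
)
for step in (0, 5, 10, 100):
    print(step, float(schedule(step)))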

We train the model again on a single example batch. Since we perform autoregressive language modeling as the task, the inputs are the tokens shifted by one position, and the targets are the original tokens. We also use a causal mask, specified in the config, to prevent the model from attending to future tokens.

[12]:
tokens = jax.random.randint(
    jax.random.PRNGKey(0),
    (config.data.batch_size, config.data.seq_len),
    1,
    config.data.vocab_size,
)
batch_transformer = Batch(
    inputs=jnp.pad(tokens[:, :-1], ((0, 0), (1, 0)), constant_values=0),
    labels=tokens,
)

Finally, we initialize the parameters of the model, in the same way as before.

[13]:
model_rng, state_rng = jax.random.split(jax.random.PRNGKey(config.seed))
params = model.init(
    model_rng,
    batch_transformer.inputs[: config.data.batch_size // config.optimizer.num_minibatches],
    train=False,
)["params"]
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
    rng=state_rng,
)

Let’s check the number of parameters below.

[14]:
def get_num_params(state: TrainState) -> int:
    return sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params))


print(f"Number of parameters: {get_num_params(state):_}")
Number of parameters: 155_877_376

With roughly 156M parameters, the model is still relatively small compared to today's language models, but still challenging to fit on a single GPU. Furthermore, with roughly 32k tokens per batch, the memory consumption of the activations is already significant.
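
As a sanity check (illustrative; derived from the config above, not part of the original notebook), we can reproduce the parameter count by hand:

# Rough parameter count from the config; matches the printed 155_877_376.
h = config.model.hidden_size                              # 1024
nh, hd = config.model.num_heads, config.model.head_dim    # 8, 128
mlp_h = config.model.mlp_expansion * h                    # 4096
V, L, S = config.data.vocab_size, config.model.num_layers, config.data.seq_len

per_block = (
    2 * h                            # attention pre-norm (scale + bias)
    + h * nh * 3 * hd + nh * 3 * hd  # fused QKV projection (kernel + bias)
    + nh * hd * h + h                # attention output projection
    + 2 * h                          # MLP pre-norm
    + h * mlp_h + mlp_h              # MLP input layer
    + mlp_h * h + h                  # MLP output layer
)
total = V * h + S * h + L * per_block + 2 * h + h * V + V  # embeddings, blocks, post-norm, output layer
print(f"{total:_}")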

Training

We can now train the model with gradient accumulation. We set the number of minibatches to 4, which means that we accumulate the gradients over 4 sub-batches. We first define a loss function, which is very similar to the classification loss we have seen before, adjusted to operate on token sequences.

[15]:
def next_token_pred_loss(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Next token prediction loss function."""
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = np.prod(batch.labels.shape)
    step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
    loss = loss.mean()
    return loss, step_metrics

We also adjust the train step to use the new loss function. Everything else remains unchanged in the train step.

[16]:
@functools.partial(
    jax.jit,
    donate_argnames=(
        "state",
        "metrics",
    ),
)
def train_step_transformer(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation for the next-token prediction task.

    Args:
        state: Current training state.
        metrics: Current metrics, accumulated from previous training steps.
        batch: Training batch.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=next_token_pred_loss,
        use_scan=True,
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

We now determine the metric shapes and initialize the metric PyTree, as we did before.

[17]:
_, metric_shapes = jax.eval_shape(
    train_step_transformer,
    state,
    None,
    batch_transformer,
)
metrics = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)

Now, we can finally train the model. The goal of the training is not to show the model’s performance, but to show the impact of the different techniques on the memory footprint and training speed. Feel free to experiment with different hyperparameters to see how the model behaves with different settings.

[18]:
for _ in tqdm(range(4)):
    state, metrics = train_step_transformer(state, metrics, batch_transformer)
final_metrics = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state, final_metrics = train_step_transformer(state, final_metrics, batch_transformer)
print_metrics(final_metrics, "Final metrics - Transformer")
 Final metrics - Transformer
accuracy: 0.000916
loss: 7.776346

Profiling

To gain further insights into the model execution and see the individual operations, we can profile the model (documentation). In JAX, profiling creates a trace file which we can view in tools like Chrome's Trace Viewer or TensorBoard. We start the profiling via jax.profiler.start_trace and stop it with jax.profiler.stop_trace. Alternatively, one can use a context manager to start and stop the profiling. For the profiling, we run three training steps to get a good overview of the model execution and to reduce the potential impact of the profiler itself. Further, we can annotate operations in the trace via jax.profiler.StepTraceAnnotation or jax.named_scope to make the trace easier to read. Finally, before stopping the trace, we wait for the last train step to finish by blocking the execution until the metrics are ready. We implement the profiling below:

[19]:
jax.profiler.start_trace("traces/")
for i in range(3):
    with jax.profiler.StepTraceAnnotation("train_step", step_num=i + 1):
        state, metrics = train_step_transformer(state, metrics, batch_transformer)
metrics["loss"][0].block_until_ready()
jax.profiler.stop_trace()
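
Equivalently, we could use the context-manager form mentioned above (illustrative; it produces the same kind of trace):

# Alternative: trace as a context manager.
with jax.profiler.trace("traces/"):
    for i in range(3):
        with jax.profiler.StepTraceAnnotation("train_step", step_num=i + 1):
            state, metrics = train_step_transformer(state, metrics, batch_transformer)
    metrics["loss"][0].block_until_ready()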

With the trace generated, we can now visualize the model execution in TensorBoard. For this, we switch to the Profiler tab and load the newest trace file. Under trace_viewer@, we can see the individual operations and their execution time. Additionally, we can inspect the used memory in the memory_viewer tab (select jit_train_step_transformer under modules). The cell below is commented out, as it may take a while to start TensorBoard, but feel free to run it to inspect the trace on your local machine.

[20]:
# %load_ext tensorboard
# %tensorboard --logdir traces/single_gpu_transformer

Since the trace will differ for different hyperparameters and hardware configurations, we have uploaded some example runs here. Feel free to download them and investigate the models yourself. All experiments were run on a single A5000 GPU, which has 24GB of memory. Below, we go through some example traces to show the impact of the individual techniques on the model execution, and explain how to read the profiler output.
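
The experiments below all reuse the setup above with a few config changes. For reference, a sketch of the variations (illustrative; re-create the model, state and jitted train step after changing the config):

# Config variations used in the comparisons below (illustrative).
config.model.dtype = jnp.float32      # vs. jnp.bfloat16 (mixed precision)
config.model.remat = ()               # vs. ("MLP", "Attn") or ("Block",) (gradient checkpointing)
config.model.scan_layers = False      # vs. True (scan over layers vs. unrolled layers)
config.optimizer.num_minibatches = 1  # vs. 4 (gradient accumulation off/on)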

Profiler Overview

The profiler in TensorBoard is a powerful tool to understand your model execution and find bottlenecks. For a full overview of the profiler, we recommend the official documentation. Here, we give a brief overview of the most important tabs and how to read the profiler output.

Trace Viewer

The trace viewer is the main tab to inspect the model execution. It shows the individual operations and their execution time. The operations are grouped by the JAX transformation, such as jit, vmap, or scan. We can inspect the execution time of the individual operations, and see which operations take the most time. This can help us to identify potential bottlenecks in the model execution, and optimize the model accordingly. Below is an example view of the trace viewer:

[Screenshot: TensorBoard trace viewer]

On the left, you have tabs to select the run and hosts if you have multiple nodes. In the middle, you see the individual operations. Here, we mainly look at the TensorFlow Name Scope rows, which show operations with their annotated names and are the easiest to interpret. On the right, you find the toolbar. The single cursor allows you to select individual blocks and see more details on them (wall clock duration, start time, etc.). The four-way arrow allows you to move around the trace. The up-down arrow lets you zoom into the trace by clicking and dragging up (zoom in) or down (zoom out), which helps to focus on specific parts of the trace and get down to the individual operations. The left-right arrow allows you to select a subset of the trace and measure the time from one operation to another, which is helpful for finding the joint execution time of multiple operations. Overall, this view shows the individual operations and their execution time, and helps us identify potential bottlenecks in the model execution.

Memory Viewer

The memory viewer shows the memory consumption of the model. It shows how the memory consumption evolves over the operations during the model execution. This can help us to identify potential memory bottlenecks and optimize the model accordingly. Below is an example view of the memory viewer:

[Screenshot: TensorBoard memory viewer]

You can hover over the memory graph to find the memory consumption at a specific point in time. Further, at the bottom, you can find the individual arrays that make up the memory consumption. This is very helpful to find the largest memory consumers and to check whether your arrays are all in the right precision, i.e. that we did not forget to cast them to bfloat16 somewhere. Overall, this view shows the memory consumption of the model and helps us identify potential memory bottlenecks in the model execution.

We will use both views to understand the impact of the individual techniques on the model execution.

Mixed Precision Training

First, we compare the model in float32 versus bfloat16 precision. For this, we adjust the config above to remove all remats and keep the batch size of 64, which fits in memory. We then profile the model with float32 and with bfloat16 precision. In the trace, we first look at the memory viewer to get an idea of the memory usage:

[Screenshot: memory viewer of the float32 model]

The float32 model is close to the maximum of the GPU memory with 20.6GB, and we already see warnings from JAX that it had to perform automatic rematerialization. This is a sign that the model is too large to comfortably fit into memory. We can further investigate the arrays that take up the most memory in the view below the memory trace.

[Screenshot: largest buffers of the float32 model]

The arrays with the largest memory usage are of shape [12, 16, 512, 4096]: the activations within the MLP blocks (12 layers, minibatch size 16, sequence length 512, MLP hidden size 4096). We can also see that these activations are in float32 precision, which is the main reason for the large memory consumption.

We can now compare this to the bfloat16 model. The memory trace of the bfloat16 model is shown below:

[Screenshot: memory viewer of the bfloat16 model]

The bfloat16 model is at 14.6GB, which is significantly less than the float32 model. Looking at the largest arrays again, we see that most activations are now in bfloat16 precision, and the previously largest arrays of shape [12, 16, 512, 4096] take only half the memory (768MB). This reduced activation precision is the main reason for the lower memory consumption.

[Screenshot: largest buffers of the bfloat16 model]

The largest remaining arrays are the softmax logits in the attention, which with shape [12, 16, 8, 512, 512] take 1.5GB (12 layers, minibatch size 16, 8 attention heads, 512 × 512 attention matrix). They remain in float32 to prevent numerical instabilities. Overall, this comparison shows the potential of mixed precision training to reduce the memory footprint of the model.
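
We can verify these numbers with a quick back-of-the-envelope calculation (illustrative):

# Back-of-the-envelope check of the largest buffers (illustrative).
mlp_act = 12 * 16 * 512 * 4096         # layers x minibatch x seq_len x MLP hidden size
attn_logits = 12 * 16 * 8 * 512 * 512  # layers x minibatch x heads x seq_len x seq_len
print(f"MLP activations:  {mlp_act * 4 / 2**20:.0f} MiB in float32, {mlp_act * 2 / 2**20:.0f} MiB in bfloat16")
print(f"Attention logits: {attn_logits * 4 / 2**30:.1f} GiB in float32")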

Memory is not the only aspect mixed precision improves. If we look at the trace_viewer tab, we can see that the execution time of the model is also significantly reduced. The float32 precision model takes 2.1 seconds per training step (see wall duration in the picture below). Note that this training step consists of 4 minibatch steps, which we can see in the 4 jvp and transpose blocks per train step.

[Screenshot: trace viewer of the float32 model]

The bfloat16 model only takes 1.1 seconds per training step, which is a significant reduction in training time. Each operation can take advantage of the bfloat16 support of the GPU's tensor cores, which allows for the significant speedup. This shows that mixed precision training reduces the training time of the model as well as its memory footprint.

[Screenshot: trace viewer of the bfloat16 model]

Scanning Layers

Before we continue with the other techniques, we take a closer look at the trace to identify potential model inefficiencies. For this, we zoom into the trace_viewer tab and look at the individual operations. We see the operations within each block (e.g. mlp and attn), but also that there is quite some gap between the execution of subsequent layers. On closer inspection, many of these gaps are due to the recurring operation dynamic_update_slice:

[Screenshot: trace viewer showing dynamic_update_slice gaps between layers]

This operation copies one array into another and is often used in the scan transformation to update the global state of the loop with the buffers of the individual layers. However, it is quite slow here, since copying large arrays within GPU memory is expensive compared to the fast layer execution. This is a sign that the scan transformation is not optimal for this model, and we should consider sacrificing some compilation time for a more efficient model execution, especially since the model is not extremely deep.

Hence, we test our model with scan_layers=False. While the compilation time increases, it stays within a few seconds, which is negligible compared to the overall training time. We show the trace of the new model below. The execution time is significantly reduced to 0.73 seconds instead of 1.1 seconds, and the dynamic_update_slice operations are gone.

[Screenshot: trace viewer with scan_layers=False]

Furthermore, the peak memory is also reduced to 8.8GB instead of 14.6GB, which is a significant reduction in memory consumption. This is because we no longer force the model to keep the full activations of all layers in memory, and the memory of a layer can be released as soon as its gradients have been calculated:

[Screenshot: memory viewer with scan_layers=False]

As a result, we find many more small arrays in our buffer, which are the activations of the individual layers. While this gives the compiler more freedom to schedule the computation, we may suffer more from memory fragmentation. However, for the model at hand, this is not a significant issue, and we find a clear reduction in both memory consumption and execution time when not scanning the layers.

This insight should not be taken as a general rule, but as a reminder to always profile the model and consider the trade-offs of different techniques. For larger models, the scan transformation can be beneficial to reduce the compilation time, but for smaller models, it can be more beneficial to not scan the layers to reduce the memory consumption and execution time.

Gradient Checkpointing

Another situation where scanning the layers becomes efficient again is when we combine it with gradient checkpointing. By recomputing most activations, we reduce the memory that needs to be carried between loop iterations of the scan and thus significantly reduce the cost of the dynamic_update_slice operations. For instance, we trace a model using scan and config.remat=("MLP", "Attn"). This corresponds to checkpointing the input activations of the MLP and attention blocks, but recomputing their inner activations. We show the trace below:

[Screenshot: trace viewer with scan and remat of the MLP and attention blocks]

The model takes 0.91 seconds per training step, which is about 25% slower than the unscanned model without rematting. Still, it is faster than the scanned model without rematting, since we reduce the memory that needs to be carried between loop iterations. In the trace, the dynamic_update_slice operations now take a negligible amount of time. To verify that the model performs gradient checkpointing as intended, we can zoom into the backward pass. There, we see that each block executes rematted_computation blocks, which correspond to recomputing the activations during the backward pass:

[Screenshot: backward pass with rematted_computation blocks]

Let's also check the memory consumption of the model, since this is the main goal of gradient checkpointing. The peak memory, shown below, is reduced to only 3.9GB, significantly less than the 14.6GB of the scanned model without rematting.

[Screenshot: memory viewer with remat of the MLP and attention blocks]

Furthermore, the largest arrays left in the buffer are the MLP parameters of the model. This indicates that we could significantly increase the model size and batch size with gradient checkpointing, which we could not do without rematting.

[Screenshot: largest buffers with remat of the MLP and attention blocks]

Besides rematting the MLP and attention blocks, we could also remat the full Transformer block. However, since the activations are no longer the limiting factor for the memory consumption, there is no significant benefit: the model uses 3.8GB of memory, only slightly less than with rematting the MLP and attention blocks, while the execution time is slightly slower at 0.96 seconds per training step. In our case, this is likely not worth the small reduction in memory consumption.

Nonetheless, these experiments show the potential of gradient checkpointing to reduce the memory footprint of the model and allow for larger models and batch sizes.

Gradient Accumulation

With mixed precision and gradient checkpointing, we saved so much memory that we no longer need gradient accumulation. To check this, we run a model with bfloat16, remat=("MLP", "Attn"), and the number of minibatches set to 1, i.e. no gradient accumulation. We first show the memory consumption below:

[Screenshot: memory viewer without gradient accumulation]

The model takes 6.2GB of memory, which is an increase over the gradient accumulation model, but still significantly less than the 24GB maximum of the GPU. Further, we can check the execution time of the model by looking at the trace:

[Screenshot: trace viewer without gradient accumulation]

With 0.86 seconds per training step, the model is slightly faster than the model with gradient accumulation. This is because it can parallelize operations better and utilize the GPU more efficiently. Hence, we may want to reduce the use of gradient accumulation whenever the GPU memory fits the full batch.

Furthermore, we can scale the batch size well beyond 64. For instance, a batch size of 256 fits well into the memory (15GB usage), while the initial model hit the memory limit with a minibatch size of 16. This shows the potential of the combined techniques to reduce the memory footprint of the model and allow for larger models and batch sizes even on a single GPU.

Conclusion

In this notebook, we have explored several techniques to train larger models on a single device. We have implemented mixed precision training, gradient accumulation, and gradient checkpointing on a simple MLP model, and discussed JAX-specific structures to reduce the memory footprint of the model. We have also trained a larger Transformer model with these techniques and profiled the model to gain further insights into its execution.

We have seen that these techniques can significantly reduce the memory footprint of the model and help train larger models. However, they also come with trade-offs, such as increased training time and reduced numerical precision. It is important to carefully consider these trade-offs when training larger models, and to experiment with different techniques to find the best setup for the specific model and hardware configuration. We have also seen that JAX provides a powerful backend with the XLA compiler to optimize our computations on the available hardware, and that we can use the profiler to gain further insights into the model execution.

We hope that this notebook has provided a good overview of the techniques to train larger models on a single GPU, and has given a good starting point for further exploration. In the following notebooks, we will explore how to train larger models on multiple GPUs and TPUs, and discuss the different parallelization strategies to scale the training to multiple devices.

References and Resources

[Bulatov, 2018] Bulatov, Y., 2018. Fitting larger networks into memory. Blog post link

[Kalamkar et al., 2019] Kalamkar, D., Mudigere, D., Mellempudi, N., Das, D., Banerjee, K., Avancha, S., Vooturi, D.T., Jammalamadaka, N., Huang, J., Yuen, H. and Yang, J., 2019. A study of BFLOAT16 for deep learning training. arXiv preprint arXiv:1905.12322. Paper link

[Ahmed et al., 2022] Ahmed, S., Sarofeen, C., Ruberry, M., et al., 2022. What Every User Should Know About Mixed Precision Training in PyTorch. Tutorial link

[Raschka, 2023] Raschka, S., 2023. Optimizing Memory Usage for Training LLMs and Vision Transformers in PyTorch. Tutorial link (gives more details for the topics here in PyTorch)

[HuggingFace, 2024] HuggingFace, 2024. Performance and Scalability: How To Fit a Bigger Model and Train It Faster. Tutorial link

[NVIDIA, 2024] NVIDIA, 2024. Mixed Precision Training. Documentation link

[NVIDIA, 2024] NVIDIA, 2024. Performance Guide for Training. Documentation link

[Google, 2024] JAX Team Google, 2024. Control autodiff’s saved values with jax.checkpoint (aka jax.remat). Tutorial link

[Google, 2024] JAX Team Google, 2024. Profiling JAX programs. Tutorial link

[Google, 2024] JAX Team Google, 2024. GPU performance tips. Tutorial link

