Part 3.1: Pipeline Parallelism


Author: Phillip Lippe

In the previous tutorial, we saw data parallelism, which works well for moderately large models and large batch sizes. However, as we continue to increase the model size, the batch size per device can become very small, which can lead to inefficient usage of the accelerators. On the one hand, the time spent on communication between devices can become significant compared to the time spent on computation. On the other hand, small batch sizes can lead to poor utilization of the device’s computational resources, since fewer operations can be performed in parallel.

In such cases, we need to consider parallelism strategies that parallelize the model itself rather than the data. Such strategies can achieve high utilization even if the batch size per device is small. In this notebook series, we will consider two such strategies: pipeline parallelism and tensor parallelism. Intuitively, pipeline parallelism splits the model across layers, while tensor parallelism splits the model across feature dimensions. We visualize these strategies in the figure below.
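To make the distinction concrete, the NumPy sketch below (an illustration under simplifying assumptions, not JAX sharding code) runs a toy two-layer linear model both ways on two hypothetical devices. Pipeline parallelism gives each device whole layers and only passes activations between them, while tensor parallelism gives each device a slice of every weight matrix along the feature dimension. Both recover the result of the unsplit model.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))  # batch of 4 examples with 8 features
w1 = rng.normal(size=(8, 8))
w2 = rng.normal(size=(8, 8))
reference = x @ w1 @ w2  # unsplit two-layer linear model

# Pipeline parallelism: split across layers. "Device" 0 holds w1, "device" 1 holds w2,
# and only the activation h is handed from one device to the next.
h = x @ w1  # computed on device 0
pipeline_out = h @ w2  # computed on device 1

# Tensor parallelism: split every layer along its output features.
# Each of the 2 devices holds half of the columns of w1 and of w2.
w1_shards = np.split(w1, 2, axis=1)
w2_shards = np.split(w2, 2, axis=1)
h_full = np.concatenate([x @ s for s in w1_shards], axis=1)  # gather activation halves
tensor_out = np.concatenate([h_full @ s for s in w2_shards], axis=1)

assert np.allclose(pipeline_out, reference)
assert np.allclose(tensor_out, reference)
```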


In this notebook, we will focus on pipeline parallelism. We will first discuss the concept of pipeline parallelism, and then show how to implement it in JAX.

Pipeline parallelism is a way to parallelize the forward and backward passes of a model across multiple devices. In contrast to data parallelism, which replicates the model across devices, the model is split along its layers into multiple stages. Each stage consists of multiple layers of the model and is placed on a different device. The output of each stage is passed to the next stage, and the final output is the result of the last stage. For example, consider a Transformer model with 12 layers and 4 devices. In this case, we can split the model into 4 stages, each consisting of 3 layers. The first three layers are placed on the first device, the next three on the second device, and so on. Given an input batch, we start by passing it to stage 1 on the first device. The output of stage 1 is then passed to stage 2, and so on, until the final output is produced by stage 4. The backward pass is performed in the reverse order, starting from the last stage and ending at the first stage. This way, each device only needs to hold a subset of the model, which reduces the memory requirements and allows larger models to be trained. At the same time, we introduce little communication between devices, as each stage only passes its output to the next stage. This leads to a computation graph similar to the one in the figure below (\(F_i\) - forward pass of stage \(i\), \(B_i\) - backward pass of stage \(i\), Update - optimizer step and gradient communication). We note that the precise timings may depend on additional factors, such as communication speed, the cost of the backward pass, and additional data parallelism (especially for the update step). For now, however, let’s focus on the basic idea.
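The 12-layer, 4-stage example can be sketched in NumPy as follows. The residual tanh layers are stand-ins for Transformer layers (an assumption for illustration); the point is that running the stages one after another, each with only its own contiguous group of layers, reproduces the output of the unsplit model.

```python
import numpy as np

num_layers, num_stages = 12, 4
layers_per_stage = num_layers // num_stages
# Contiguous assignment: stage s owns layers [3 * s, 3 * s + 3).
stage_of_layer = [layer // layers_per_stage for layer in range(num_layers)]

rng = np.random.default_rng(0)
# Toy residual "layers"; small weights keep the activations well-behaved.
weights = [0.1 * rng.normal(size=(16, 16)) for _ in range(num_layers)]


def run_layers(x, layer_ids):
    """Apply the given layers in order (stand-in for a stage's forward pass)."""
    for i in layer_ids:
        x = x + np.tanh(x @ weights[i])
    return x


x = rng.normal(size=(32, 16))
full = run_layers(x, range(num_layers))

# Pipelined forward pass: each stage only touches its own layers and
# hands its output activation to the next stage.
activation = x
for stage in range(num_stages):
    own_layers = [i for i, s in enumerate(stage_of_layer) if s == stage]
    activation = run_layers(activation, own_layers)

assert np.allclose(activation, full)
print(stage_of_layer)  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```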


When looking at the figure, we can see that each stage can only start its forward pass once the previous stage has produced its output, and in the backward pass, each stage has to wait for the gradients of the following stage. Hence, for a large fraction of the time, the devices are idle, waiting for the output of another stage. This is known as the “pipeline bubble” problem: the utilization of the devices is reduced by the time spent waiting for the output of the previous stage. In this notebook, we will discuss and implement simple strategies to mitigate the pipeline bubble problem, and show how this can be done efficiently in JAX.
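A rough way to quantify the idle time is to tabulate which device works in each time slot. The sketch below assumes that every forward and backward stage takes exactly one time unit and ignores communication and the update step (real backward passes are typically more expensive). Under these assumptions, each of the \(K\) devices is busy in only 2 of \(2K\) slots, so a fraction \((K-1)/K\) of the time is bubble.

```python
def naive_schedule(num_stages):
    """Per-time-slot work of each device in a pipeline without micro-batching.

    Assumes every forward and backward stage takes one time unit, and ignores
    communication and the optimizer update.
    """
    timeline = []  # timeline[t][d]: work of device d at time t ("" means idle)
    for stage in range(num_stages):  # forward: F1, F2, ..., F_K
        timeline.append([f"F{stage + 1}" if d == stage else "" for d in range(num_stages)])
    for stage in reversed(range(num_stages)):  # backward: B_K, ..., B1
        timeline.append([f"B{stage + 1}" if d == stage else "" for d in range(num_stages)])
    return timeline


num_stages = 4
timeline = naive_schedule(num_stages)
for t, row in enumerate(timeline):
    print(t, " ".join(slot or "--" for slot in row))

busy = sum(slot != "" for row in timeline for slot in row)
total = len(timeline) * num_stages
print(f"idle fraction: {1 - busy / total:.2f}")  # 0.75 = (K - 1) / K for K = 4
```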


First, let’s set up the basic environment and the utility functions we have seen in previous notebooks. We download the Python scripts of the previous notebooks below. This is only needed when running on Google Colab; local execution will skip this step automatically.

import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = ""
# Files to download.
python_files = ["", "", ""]
# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

As before, we simulate 8 devices on CPU to demonstrate the parallelism without the need for multiple GPUs or TPUs. If you are running on your local machine and have multiple GPUs available, you can comment out the lines below.

from utils import simulate_CPU_devices

simulate_CPU_devices()

We now import our standard libraries.

import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core.frozen_dict import FrozenDict
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict

# Helper types
PyTree = Any
Parameter = jax.Array | nn.Partitioned
Metrics = Dict[str, Tuple[jax.Array, ...]]

We also import the utility functions from the previous notebooks.

from data_parallel import fold_rng_over_axis, sync_gradients
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics

Pipeline Parallelism with Micro-Batching

The first strategy to mitigate the pipeline bubble problem is micro-batching, as introduced in GPipe [Huang et al., 2019]. The idea is to split the input batch into smaller sub-batches (micro-batches) and process them sequentially. As soon as a stage has finished a micro-batch, it communicates its output to the next stage and starts processing the next micro-batch. This way, the devices can work on other micro-batches instead of idly waiting for the output of the previous stage, which reduces the pipeline bubble.

For example, consider a batch of size 32 and a pipeline with 4 stages. We can split the batch into 4 micro-batches of size 8 (any other divisor of 32 works as well) and process them sequentially. As soon as the first micro-batch is processed by stage 1, we communicate the output to stage 2 and start processing the second micro-batch on stage 1, and so on. The figure below shows the computation graph for the forward and backward passes of the pipeline with this micro-batching strategy.


Compared to the original pipeline, we can see that the devices are kept busy for a larger portion of the time, as they process the micro-batches one after another. However, the communication between stages is now more frequent, as we need to communicate the output of each micro-batch. This can increase the communication overhead, especially for small micro-batches. In practice, the choice of the micro-batch size is a trade-off between the pipeline bubble, the communication overhead, and the maximum utilization we can achieve per device with this micro-batch size.
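Under the simplifying assumption that every stage takes the same time per micro-batch, the forward schedule of this micro-batched pipeline can be written down directly: stage \(s\) works on micro-batch \(m\) at step \(m + s\). With \(M\) micro-batches and \(K\) stages, this takes \(M + K - 1\) steps, during which each stage is idle for \(K - 1\) steps, giving a bubble fraction of \((K-1)/(M+K-1)\). The sketch below is a plain-Python illustration, not part of the JAX implementation.

```python
def gpipe_forward_schedule(num_stages, num_microbatches):
    """Micro-batch index processed by each stage at each step (None = idle).

    Stage s can work on micro-batch m at step m + s, once stage s - 1 has
    finished it. Assumes all stages take equal time per micro-batch.
    """
    num_steps = num_microbatches + num_stages - 1
    schedule = []
    for t in range(num_steps):
        row = []
        for s in range(num_stages):
            m = t - s
            row.append(m if 0 <= m < num_microbatches else None)
        schedule.append(row)
    return schedule


def bubble_fraction(num_stages, num_microbatches):
    """Fraction of idle (step, stage) slots, i.e. (K - 1) / (M + K - 1)."""
    schedule = gpipe_forward_schedule(num_stages, num_microbatches)
    idle = sum(m is None for row in schedule for m in row)
    return idle / (len(schedule) * num_stages)


for row in gpipe_forward_schedule(num_stages=4, num_microbatches=4):
    print(row)
for m in (1, 4, 8, 32):
    print(m, round(bubble_fraction(4, m), 3))
```

With four stages, the bubble shrinks from 0.75 for a single micro-batch to 3/11 (about 0.27) for eight micro-batches, but every step then transfers a smaller array, which is where the communication trade-off comes in.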

In the following, we will show how to implement pipeline parallelism with micro-batching in JAX. We will use a simple MLP model for demonstration (to make it feasible on CPU), and show how to split the model across stages. For any other model, such as a larger Transformer, the same principles apply, and the model can be used with the pipeline wrapper below without changes.

Module Preparation

We start with implementing a simple MLP model, which requires no changes compared to a non-pipeline model. The MLP consists of multiple residual blocks as in a Transformer model, just without the attention (for simplicity). As in the previous notebook, we use a ConfigDict to store the model hyperparameters, a scan for reducing compilation time, and support remat and mixed precision.

class MLPBlock(nn.Module):
    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        input_features = x.shape[-1]
        residual = x
        x = nn.LayerNorm(dtype=self.config.dtype, name="pre_norm")(x)
        x = nn.Dense(
            features=self.config.hidden_size * self.config.mlp_expansion,
            dtype=self.config.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not self.train)(x)
        x = nn.Dense(features=input_features, dtype=self.config.dtype, name="output_dense")(x)
        return x + residual

class MLPLayers(nn.Module):
    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        # Scan version
        block_class = MLPBlock
        if "MLP" in self.config.remat:
            block_class = nn.remat(block_class, prevent_cse=False)
        block = block_class(config=self.config, train=self.train, name="block")
        x, _ = nn.scan(
            lambda module, carry, _: (module(carry), ()),
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            length=self.config.num_layers,
        )(block, x, ())
        # Non-scanned version
        # for i in range(self.config.num_layers):
        #     x = block_class(self.config, train=self.train, name=f"block_{i}")(x)
        return x

Similar to the previous notebook on data parallelism, we have parameters with different values on different devices. Note that we generally allow for multiple axes in our mesh, since model parallelism is often combined with data parallelism. Thus, we wrap the parameters in an nn.Partitioned class to annotate their sharding over the model axis (for simplicity, we do not consider FSDP here, but show how it can be easily added in a later notebook). This way, we can easily split the parameters across devices and use the same model definition for all stages.

Compared to last time, we stack the parameters over a new axis that we create. While one could also concatenate them along the layer index that is introduced by nn.scan, it would be unintuitive and error-prone for settings or parameters that are not scanned over. Still, this is only a design choice and both ways can be used.
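The difference between the two options can be seen from the array shapes alone. In this NumPy sketch, the per-device arrays are hypothetical stand-ins for a scanned parameter (with a leading layer axis of size 2) and an unscanned parameter (without a layer axis); the sizes are chosen to match the parameter shapes printed later in this notebook.

```python
import numpy as np

num_stages, layers_per_stage, features = 4, 2, 512
# What each device holds locally after nn.scan: a leading layer axis.
local_scanned = [np.full((layers_per_stage, features), s, dtype=np.float32) for s in range(num_stages)]
# A parameter outside the scan (e.g. a final norm) has no layer axis.
local_unscanned = [np.full((features,), s, dtype=np.float32) for s in range(num_stages)]

# Stacking over a new axis treats both kinds of parameters the same way.
stacked_scanned = np.stack(local_scanned, axis=0)
stacked_unscanned = np.stack(local_unscanned, axis=0)
print(stacked_scanned.shape, stacked_unscanned.shape)  # (4, 2, 512) (4, 512)

# Concatenating along the layer axis also works for scanned parameters, but the
# stage boundaries become implicit, and unscanned parameters have no such axis.
concatenated = np.concatenate(local_scanned, axis=0)
print(concatenated.shape)  # (8, 512)
```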

Below, we implement functions that will support the partitioning of parameters in this fashion via the nn.map_variables transform.

def stack_params(
    params: PyTree, axis_name: str, axis: int = 0, mask_except: jax.Array | int | None = None
) -> PyTree:
    """Stacks sharded parameters along a given axis name.

    Args:
        params: PyTree of parameters.
        axis_name: Name of the axis to stack along.
        axis: Index of the axis to stack along.
        mask_except: If not None, only the `mask_except`-th shard will be non-zero.

    Returns:
        PyTree of parameters with the same structure as `params`, but with the leaf
        nodes replaced by `nn.Partitioned` objects with sharding over axis name added
        to `axis`-th axis of parameters.
    """

    def _stack(x: Parameter) -> Parameter:
        if isinstance(x, nn.Partitioned):
            value, names = x.value, x.names
        else:
            value, names = x, (None,) * x.ndim
        if mask_except is not None:
            axis_index = jax.lax.axis_index(axis_name)
            value = jnp.where(axis_index == mask_except, value, 0.0)
        value = jnp.expand_dims(value, axis)
        names = names[:axis] + (axis_name,) + names[axis:]
        return nn.Partitioned(value, names=names)

    return jax.tree_map(_stack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))

def unstack_params(params: PyTree, axis_name: str) -> PyTree:
    """Unstacks parameters along a given axis name.

    Inverse operation to `stack_params`.

    Args:
        params: PyTree of parameters.
        axis_name: Name of the axis to unstack along.

    Returns:
        PyTree of parameters with the same structure as `params`, but
        with the leaf nodes having the sharding over the axis name removed.
    """

    def _unstack(x: Parameter) -> Parameter:
        if isinstance(x, nn.Partitioned) and axis_name in x.names:
            value = x.value
            names = x.names
            axis_idx = names.index(axis_name)
            value = value.squeeze(axis_idx)
            names = names[:axis_idx] + names[axis_idx + 1 :]
            if all([n is None for n in names]):
                return value
            else:
                return nn.Partitioned(value, names=names)
        else:
            return x

    return jax.tree_map(_unstack, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))
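To see what stack_params and unstack_params do on each device, here is a NumPy stand-in under two assumptions: the device index is passed explicitly instead of coming from jax.lax.axis_index, and the global array seen outside shard_map is modeled as the concatenation of the per-device blocks along the new model axis. The helper names stack_local and unstack_local are hypothetical.

```python
import numpy as np


def stack_local(value, axis_index, axis=0, mask_except=None):
    """NumPy stand-in for `_stack` on a single device.

    `axis_index` plays the role of `jax.lax.axis_index(axis_name)`.
    """
    if mask_except is not None:
        value = np.where(axis_index == mask_except, value, 0.0)
    return np.expand_dims(value, axis)


def unstack_local(value, axis=0):
    """NumPy stand-in for `_unstack`: remove the size-1 model axis again."""
    return value.squeeze(axis)


num_devices = 4
local_kernels = [np.full((3, 3), d + 1.0) for d in range(num_devices)]

# Each device adds a size-1 model axis (masking all but device 0, as for the input layer).
blocks = [stack_local(k, d, mask_except=0) for d, k in enumerate(local_kernels)]
# Modeled global view: per-device blocks concatenated along the model axis.
global_view = np.concatenate(blocks, axis=0)
print(global_view.shape)  # (4, 3, 3)
print(global_view[:, 0, 0])  # [1. 0. 0. 0.]

# unstack_local undoes stack_local on every device.
for d, (block, kernel) in enumerate(zip(blocks, local_kernels)):
    assert np.array_equal(unstack_local(block), np.where(d == 0, kernel, 0.0))
```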

Pipeline Implementation

The pipeline implementation iterates over two simple steps: applying each device’s stage to a micro-batch, and communicating the last output between devices to obtain the next input. Whenever no proper input is ready, we still run the stage to follow the SPMD principle (single program, multiple data), but ignore the output. Communication is performed with jax.lax.ppermute, which transfers an array over a ring: stage 1 sends its output to stage 2, stage 2 to stage 3, stage 3 to stage 4, and finally, stage 4 to stage 1. The last communication, from stage 4 to stage 1, can be ignored, as it only closes the ring and stage 1 uses the original micro-batches as input. Instead, we store the last \(N\) outputs of stage 4, where \(N\) is the number of micro-batches, and use them for the loss calculation in the final layers, since those are the outputs of the full model. In summary, the computation graph for the forward pass will look similar to the figure below:


Note that for simplicity, we didn’t visualize unused communications (e.g. stage 4 to stage 1). Finally, the backward pass will be automatically handled by JAX’s jax.grad transformation.
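Before looking at the JAX implementation, the loop can be simulated on a single machine with NumPy. In this sketch, toy tanh stages replace MLPLayers, and np.roll along a leading stage axis stands in for the ring-shaped jax.lax.ppermute (both are assumptions for illustration). Every stage runs at every step in SPMD fashion, the first stage reads the original micro-batch, and only the last \(N\) outputs of the last stage are kept. The result matches running the full batch through all stages sequentially.

```python
import numpy as np

num_stages, num_microbatches, microbatch_size, features = 4, 8, 2, 5
rng = np.random.default_rng(0)
stage_weights = [0.3 * rng.normal(size=(features, features)) for _ in range(num_stages)]


def stage_fn(stage, x):
    """Toy stand-in for the per-device stage module."""
    return np.tanh(x @ stage_weights[stage])


batch = rng.normal(size=(num_microbatches * microbatch_size, features))
microbatches = batch.reshape(num_microbatches, microbatch_size, features)
# Pad with K - 1 dummy inputs so that the last micro-batch reaches the last stage.
inputs = np.concatenate([microbatches, np.zeros((num_stages - 1, microbatch_size, features))])

state = np.zeros((num_stages, microbatch_size, features))  # state[s] lives on device s
last_stage_outputs = []
for step_input in inputs:  # M + K - 1 iterations, as in the scan
    state[0] = step_input  # the first stage reads the original micro-batch
    state = np.stack([stage_fn(s, state[s]) for s in range(num_stages)])  # all stages run
    last_stage_outputs.append(state[-1])  # keep only the last stage's output
    state = np.roll(state, shift=1, axis=0)  # ring permute: stage s -> stage (s + 1) % K

# The first K - 1 outputs of the last stage come from unused computation blocks.
outputs = np.concatenate(last_stage_outputs[-num_microbatches:], axis=0)

# Reference: run the whole batch through all stages sequentially.
reference = batch
for s in range(num_stages):
    reference = stage_fn(s, reference)
assert np.allclose(outputs, reference)
```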

We implement this loop with a nn.scan operation, which keeps the parameters across all steps the same, but allows for different inputs and outputs, as well as updating the RNGs (used in Dropout). First, we implement a single step of the loop below:

def execute_pipeline_step(
    module: nn.Module,
    state: jax.Array,
    input: jax.Array,
    *args,
    model_axis_name: str,
    **kwargs,
) -> Tuple[jax.Array, jax.Array]:
    """Single micro-batch pipeline step.

    Args:
        module: Flax module representing the stage to execute.
        state: Last communicated features between stages. Used as input to the module for all stages except the first.
        input: Original micro-batch input to the pipeline stage. Used as input to the module for the first stage.
        *args: Additional arguments to the module.
        model_axis_name: Name of the model axis in the mesh/shard_map.
        **kwargs: Additional keyword arguments to the module.

    Returns:
        Tuple of the new state (after communication) and the output of the module.
    """
    num_stages = jax.lax.psum(1, model_axis_name)
    stage_index = jax.lax.axis_index(model_axis_name)
    # For the first stage, we use the microbatches as input.
    # For all other stages, we use the last state from the
    # previous stage as input.
    state = jnp.where(stage_index == 0, input, state)
    state = module(state, *args, **kwargs)
    # For the last stage, we return the state as output.
    # For all other stages, we return zeros.
    output = jnp.where(
        stage_index == num_stages - 1,
        state,
        jnp.zeros_like(state),
    )
    # Communicate the last state to the next stage (ring: i -> (i + 1) % num_stages).
    state = jax.lax.ppermute(
        state,
        axis_name=model_axis_name,
        perm=[(i, (i + 1) % num_stages) for i in range(num_stages)],
    )
    return (state, output)

With the single step implemented, we can now wrap it in a nn.scan operation to iterate over the micro-batches. Compared to the scan over layers in the MLPLayers module, we now scan over the micro-batches and keep the parameters the same. The latter is controlled by setting variable_broadcast to {"params": True} and not splitting the params RNG over iterations. To scan over the input and output, we add in_axes=0 and out_axes=0, which effectively unstacks the input and stacks the output over the first axis across iterations. In other words, the first iteration gets the first micro-batch as input, the second iteration the second micro-batch, and so on. Additionally, the current state of the stages is passed between iterations as the carry.
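The scan semantics described above can be mimicked by a short pure-Python function. It is a hypothetical helper, not Flax’s nn.scan, and it omits the broadcasting of parameters and the splitting of RNGs; it only shows how the inputs are unstacked along the first axis, how the carry is threaded through the iterations, and how the per-iteration outputs are stacked. In execute_pipeline, the carry corresponds to the communicated state, the inputs to the padded micro-batches, and the outputs to the last-stage outputs.

```python
import numpy as np


def scan(fn, carry, xs):
    """Pure-Python sketch of scan with in_axes=0 and out_axes=0.

    `xs` is unstacked along its first axis (one slice per iteration), `carry`
    is threaded through the iterations, and the per-iteration outputs are
    stacked along a new first axis.
    """
    outputs = []
    for x in xs:
        carry, y = fn(carry, x)
        outputs.append(y)
    return carry, np.stack(outputs)


xs = np.arange(12.0).reshape(4, 3)  # 4 iterations, one row each
# Carry: running sum of the rows. Per-step output: the row times two.
final, ys = scan(lambda c, x: (c + x, 2 * x), np.zeros(3), xs)
print(final)  # [18. 22. 26.]
print(ys.shape)  # (4, 3)
```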

The final difference in the nn.scan is that we do not scan a single module, but rather the function execute_pipeline_step. Flax allows this as long as the function takes a module as its first argument. This module is then scanned as specified by the other keyword arguments, while the passed function is executed at each iteration. This results in the following pipeline function:

@jax.named_scope("pipeline")  # Naming scope for profiling.
def execute_pipeline(
    module: nn.Module, x: jax.Array, *args, num_microbatches: int, model_axis_name: str, **kwargs
) -> jax.Array:
    """Execute a pipeline of stages on a batch of data.

    Uses the principle of GPipe in splitting the batch into micro-batches
    and running the pipeline stages in parallel.

    Args:
        module: Flax module representing the pipeline stage to execute.
        x: Batch of input data, only needed on device of the first stage. Data will be split into micro-batches.
        *args: Additional arguments to the module.
        num_microbatches: Number of micro-batches to split the batch into.
        model_axis_name: Name of the model axis in the mesh/shard_map.
        **kwargs: Additional keyword arguments to the module.

    Returns:
        Output of the last stage of the pipeline. For devices that are not
        the last stage, the output is zeros.
    """
    num_stages = jax.lax.psum(1, model_axis_name)
    # Structure the input data into micro-batches.
    batch_size = x.shape[0]
    assert (
        batch_size % num_microbatches == 0
    ), f"Batch size {batch_size} must be divisible by number of microbatches {num_microbatches}"
    microbatch_size = batch_size // num_microbatches
    microbatches = jnp.reshape(x, (num_microbatches, microbatch_size, *x.shape[1:]))
    inputs = jnp.concatenate(  # Add zeros for unused computation blocks in first stage.
        [
            microbatches,
            jnp.zeros((num_stages - 1, *microbatches.shape[1:]), dtype=microbatches.dtype),
        ],
        axis=0,
    )
    state = jnp.zeros_like(microbatches[0])
    num_iterations = inputs.shape[0]
    # Run loop over pipeline steps.
    _, outputs = nn.scan(
        # Scan a function that takes the module as first argument; the extra
        # positional and keyword arguments are closed over (not scanned).
        lambda module, carry, inp: execute_pipeline_step(
            module, carry, inp, *args, model_axis_name=model_axis_name, **kwargs
        ),
        variable_broadcast={"params": True},
        split_rngs={"params": False, "dropout": True},
        length=num_iterations,
        in_axes=0,
        out_axes=0,
    )(module, state, inputs)
    # Take last N outputs (first ones are zeros from unused computation blocks in last stage).
    outputs = jnp.concatenate(outputs[-num_microbatches:], axis=0)
    return outputs

We can now use this pipeline function to define the model. First, we write a small module wrapper that creates a module and executes it using the execute_pipeline function.

class PipelineModule(nn.Module):
    model_axis_name: str
    num_microbatches: int
    module_fn: Callable[..., nn.Module]

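    @nn.compact
    def __call__(self, *args, **kwargs):
        module = self.module_fn()
        return execute_pipeline(
            module,
            *args,
            **kwargs,
            num_microbatches=self.num_microbatches,
            model_axis_name=self.model_axis_name,
        )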

Before we can use the pipeline model, we need to shard the parameters. As for the fully-sharded data parallel, we do this by wrapping the module in a nn.map_variables, in which we use our two previous functions stack_params and unstack_params to shard the parameters over the model axis. We also need to initialize the parameters on each device differently, which we do by folding the RNG of the parameters over the model axis. With that, each device uses a different RNG key, and thus generates different parameters. Since this is a module that can be used for almost any layer that should be sharded across the model axis, we refer to it as a ModelParallelismWrapper:
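The effect of folding the RNG over the model axis can be illustrated with NumPy’s SeedSequence, whose spawn_key plays a role similar to folding an index into a JAX key with jax.random.fold_in. The shared seed of 42 and the helper init_kernel are hypothetical, and we assume that fold_rng_over_axis from the data-parallel notebook derives the per-device key from the device’s index along the given axis in an analogous way.

```python
import numpy as np

base_seed = 42  # hypothetical shared "params" seed, identical on every device
num_model_devices = 4


def init_kernel(model_index, shape=(2, 2)):
    """Initialize a toy kernel from the shared seed folded with the model-axis index."""
    seed_seq = np.random.SeedSequence(base_seed, spawn_key=(model_index,))
    return np.random.default_rng(seed_seq).normal(size=shape)


kernels = [init_kernel(i) for i in range(num_model_devices)]
# Without folding, every stage would start from identical weights.
assert not any(np.allclose(kernels[0], k) for k in kernels[1:])
# Folding is deterministic: the same device index reproduces the same weights.
assert np.allclose(init_kernel(2), kernels[2])
```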

class ModelParallelismWrapper(nn.Module):
    """Wrapper for adding model parallelism to a module.

    This wrapper adds sharding over the model axis to the parameters of the module
    and initializes the module with different parameters across the model axis.

    Attributes:
        model_axis_name: Name of the model axis to shard over.
        module_fn: Function that returns the Flax module to wrap.
        mask_except_model_idx: If not None, only the `mask_except_model_idx`-th shard will be non-zero.
        split_rngs: If True, split the random number generators across the model axis.
        module_kwargs: Additional keyword arguments to pass to the module function.
    """

    model_axis_name: str
    module_fn: Callable[..., nn.Module]
    mask_except_model_idx: int | None = None
    split_rngs: bool = True
    module_kwargs: FrozenDict[str, Any] = FrozenDict({})

    @nn.compact
    def __call__(self, *args, **kwargs):
        if self.is_initializing() and self.split_rngs:
            # Initialize each module across the model axis with different parameters.
            self.scope.rngs["params"] = self.scope.rngs["params"].replace(
                rng=fold_rng_over_axis(self.scope.rngs["params"].rng, self.model_axis_name)
            )
        # Wrap variables in nn.Partitioned objects to add sharding over the model axis.
        module = nn.map_variables(
            target=functools.partial(
                self.module_fn,
                name="sharded",
                **self.module_kwargs,
            ),
            trans_in_fn=functools.partial(unstack_params, axis_name=self.model_axis_name),
            trans_out_fn=functools.partial(
                stack_params,
                axis_name=self.model_axis_name,
                mask_except=self.mask_except_model_idx,
            ),
            mapped_collections="params",
            mutable=True,
        )()
        return module(
            *args,
            **kwargs,
        )

Combine Full Model with Pipeline

The pipeline structure assumes that each stage has the same layers and structures. However, commonly, we have an input layer, mapping the input to the first stage, and an output layer, mapping the output of the last stage to the final output. We can easily combine the pipeline model with these layers by using the pipeline model as a sub-module. We can then define the input and output layers as usual, and use the pipeline model to process the intermediate features.


Depending on the cost of the input and output layers, we can also consider splitting them across devices. This is particularly of interest if the communication cost is comparatively low on the available devices. For instance, we could perform a data-parallel strategy over the model axis for the input layer and gather all outputs on the first device before executing the pipeline. However, for setups like language models, where the input layer mainly consists of an embedding lookup, the communication between devices may be more expensive than performing the lookup on a single device, or we may consider a tensor parallel approach (more on it in the next notebook). For simplicity, in this notebook, we duplicate the computation of the input and output layers on all devices, and ignore the outputs of the input layer on all devices except the first one. Duplicating the weights across model devices would, however, lead to unnecessary communication overhead during the optimization step. Instead, we set all weights that are unused on a given device to zero. Hence, the optimizer step will not change the weights on the unused devices, and we avoid the communication overhead. We have already implemented this strategy in the ModelParallelismWrapper via the mask_except_model_idx argument. For the input layer, we mask all model devices besides the first one, and for the output layer, we mask all model devices besides the last one.
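Why do the masked copies stay at zero? Because stack_params applies jnp.where to the parameters, the gradient with respect to a masked copy is exactly zero. With zero weights and zero gradients, an AdamW step leaves the weights at zero: the moment estimates stay zero, and the decoupled weight decay multiplies a zero. The NumPy sketch below checks this for a few steps of a hand-written AdamW update with illustrative hyperparameters; it is not optax’s implementation.

```python
import numpy as np


def adamw_step(param, grad, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4):
    """One AdamW update in NumPy (Adam moments plus decoupled weight decay)."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    m_hat = m / (1 - b1**step)
    v_hat = v / (1 - b2**step)
    param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param)
    return param, m, v


model_axis_size = 4
# Output layer: only the copy on the last model device is "live".
mask = np.arange(model_axis_size) == model_axis_size - 1
param = np.where(mask[:, None], 0.5, 0.0)  # masked copies are initialized to zero
m, v = np.zeros_like(param), np.zeros_like(param)

for step in range(1, 6):
    # The masked copies do not influence the output, so their gradients are zero.
    grad = np.where(mask[:, None], 0.1, 0.0)
    param, m, v = adamw_step(param, grad, m, v, step)

assert np.all(param[~mask] == 0.0)  # unused copies stay exactly zero
assert np.all(param[mask] != 0.5)  # the live copy has been updated
```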

With this wrapper in place, we can now define the full model, and use it to train the model on a simple task. We will use a simple MLP model to classify random data, and show how to train the model using the pipeline wrapper. Thus, our full model will consist of a linear input layer, the pipeline model, and a final norm plus the linear output layer. For language models, the input layer may be an embedding layer combined with positional embeddings, and the output layer may be the same as shown here, just applied on a per-token basis.

class PPClassifier(nn.Module):
    config: ConfigDict
    pipeline_module_class: Callable[..., nn.Module] = PipelineModule

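    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        # Input layer. Only needed in the first stage.
        x = ModelParallelismWrapper(
            module_fn=functools.partial(nn.Dense, features=self.config.hidden_size, dtype=self.config.dtype),
            model_axis_name=self.config.model_axis_name,
            mask_except_model_idx=0,
            name="input_dense",
        )(x)
        # Pipeline
        stage_module_fn = functools.partial(
            MLPLayers, config=self.config, train=train, name="mlp_layers"
        )
        pipeline_module_fn = functools.partial(
            self.pipeline_module_class,
            model_axis_name=self.config.model_axis_name,
            num_microbatches=self.config.num_microbatches,
            module_fn=stage_module_fn,
        )
        module = ModelParallelismWrapper(
            module_fn=pipeline_module_fn, model_axis_name=self.config.model_axis_name, name="pipeline"
        )
        x = module(x)
        # Output layer. Only needed in the last stage.
        output_wrapper = functools.partial(
            ModelParallelismWrapper,
            model_axis_name=self.config.model_axis_name,
            mask_except_model_idx=self.config.model_axis_size - 1,
        )
        x = output_wrapper(
            module_fn=functools.partial(nn.LayerNorm, dtype=self.config.dtype), name="output_norm"
        )(x)
        x = output_wrapper(
            module_fn=functools.partial(
                nn.Dense, features=self.config.num_classes, dtype=self.config.dtype
            ),
            name="output_dense",
        )(x)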
        x = x.astype(jnp.float32)
        return x


With the model defined, we can now implement the initialization and training step. Most of the functions will be very similar to our previous notebook on data parallelism, since the model parallelism is handled within the model. This also suggests a simple composition of parallelization strategies, which we will further explore in a later notebook. For now, we focus on pipeline parallelism combined with simple data parallelism over the batch axis.

Let’s start with defining the basic config of our model below. Feel free to adjust the parameters and experiment with different settings.

data_config = ConfigDict(
model_config = ConfigDict(
model_config.num_layers //= model_config.model_axis_size  # Layers distributed over model axis.
optimizer_config = ConfigDict(
config = ConfigDict(

We can now create our device mesh. By default, we create a 2x4 mesh (for 8 devices), which means that we have a data parallel size of 2 and a model parallel size of 4. Hence, each device will process a batch of half the global size, and the model pipeline will be split into 4 stages.
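The mapping from the flat device list to the mesh can be previewed with NumPy, using integers as stand-ins for the entries of jax.devices(). Each row of the reshaped array is one data-parallel replica holding a complete pipeline, and each column is one pipeline stage. The global batch size of 128 is a hypothetical value for illustration.

```python
import numpy as np

num_devices, model_axis_size = 8, 4
# Integers stand in for jax.devices(); same reshape as in the code below.
device_ids = np.arange(num_devices).reshape(-1, model_axis_size)
print(device_ids)
# [[0 1 2 3]
#  [4 5 6 7]]
for data_idx, pipeline in enumerate(device_ids):
    print(f"data index {data_idx}: pipeline stages 1-4 on devices {pipeline.tolist()}")

global_batch_size = 128  # hypothetical value for illustration
print("batch per pipeline:", global_batch_size // device_ids.shape[0])  # 64
```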

device_array = np.array(jax.devices()).reshape(-1, config.model_axis_size)
mesh = Mesh(device_array, (config.data_axis_name, config.model_axis_name))

We then create the model object and the optimizer. We stick with simple AdamW in this example, but feel free to change the optimizer setup.

model_pp = PPClassifier(config=model_config)
optimizer = optax.adamw(

For simplicity, we will train the model on a simple random data classification task. This is mainly to demonstrate pipeline parallelism, not to achieve state-of-the-art results. In practice, one would instead create a dataset and data loader at this point, and set up data prefetching.

rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
batch = Batch(
    inputs=jax.random.normal(data_inputs_rng, (,,
        data_labels_rng, (,), 0,

The initialization function follows the same principles as in the previous notebook, creating the parameters via model.init and the optimizer parameters in the TrainState.create.

def init_fn(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    params = variables.pop("params")
    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
    return state

Before we can run the full initialization, we need to identify the partitioning of the parameters. Since we annotated the partitioning of all parameters via nn.Partitioned in the model, we can obtain it by calling jax.eval_shape on the init function. This returns the shapes of the state, including the nn.Partitioned parameter leaves. From those, we can read out the partitioning using nn.get_partition_spec. For this initial call, we can simply set the out_specs of the shard map to P(), since we do not create the actual parameters during shape evaluation.

init_pp_fn = shard_map(
    functools.partial(init_fn, model=model_pp),
    mesh,
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,
)
state_pp_shapes = jax.eval_shape(init_pp_fn, model_init_rng, batch.inputs)
state_pp_specs = nn.get_partition_spec(state_pp_shapes)

Let’s investigate the partitioning of the parameters below.

pprint(state_pp_specs.params)

{'input_dense': {'sharded': {'bias': PartitionSpec('model', None),
                             'kernel': PartitionSpec('model', None, None)}},
 'output_dense': {'sharded': {'bias': PartitionSpec('model', None),
                              'kernel': PartitionSpec('model', None, None)}},
 'output_norm': {'sharded': {'bias': PartitionSpec('model', None),
                             'scale': PartitionSpec('model', None)}},
 'pipeline': {'sharded': {'mlp_layers': {'block': {'input_dense': {'bias': PartitionSpec('model', None, None),
                                                                   'kernel': PartitionSpec('model', None, None, None)},
                                                   'output_dense': {'bias': PartitionSpec('model', None, None),
                                                                    'kernel': PartitionSpec('model', None, None, None)},
                                                   'pre_norm': {'bias': PartitionSpec('model', None, None),
                                                                'scale': PartitionSpec('model', None, None)}}}}}}

We can see that all parameters are partitioned over the model axis, as expected. Note that had we performed data parallelism over the model devices in the input and output layers, those would not be partitioned over the model axis. Similarly, had we used FSDP, the parameters would be partitioned over the data axis as well. Finally, the pipeline parameters are partitioned over the model axis on their first axis and have the scan axis (i.e. the layer axis) as their second axis. This is why the biases are three-dimensional (model, layers, features), and the kernels are four-dimensional (model, layers, features, features).

With the partitioning in place, we can now perform the full initialization of the model and optimizer.

init_pp_fn = jax.jit(
    shard_map(
        functools.partial(init_fn, model=model_pp), mesh,
        in_specs=(P(), P(config.data_axis_name)), out_specs=state_pp_specs, check_rep=False,
    )
)
state_pp = init_pp_fn(model_init_rng, batch.inputs)

Let’s inspect once more the shapes of the parameters to ensure that the initialization was successful.

pprint(jax.tree_map(lambda x: x.shape, state_pp.params["pipeline"]["sharded"]["mlp_layers"]["block"]))
{'input_dense': {'bias': Partitioned(value=(4, 2, 512),
                                     names=('model', None, None),
                 'kernel': Partitioned(value=(4, 2, 512, 512),
                                       names=('model', None, None, None),
 'output_dense': {'bias': Partitioned(value=(4, 2, 512),
                                      names=('model', None, None),
                  'kernel': Partitioned(value=(4, 2, 512, 512),
                                        names=('model', None, None, None),
 'pre_norm': {'bias': Partitioned(value=(4, 2, 512),
                                  names=('model', None, None),
              'scale': Partitioned(value=(4, 2, 512),
                                   names=('model', None, None),

As we expected, the first axis of the parameters is the model axis, thus being the same size as the model parallel size. The second axis is the layer axis, with the model in default configuration having 2 layers per stage, i.e. 8 layers in total.
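As a sanity check, the following plain-Python sketch recomputes the global shapes and parameter counts of the pipeline blocks from the shapes printed above (hidden size 512, 2 layers per stage, 4 stages).

```python
def numel(shape):
    """Number of elements of an array with the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count


hidden, model_axis_size, layers_per_stage = 512, 4, 2
# Per-layer shapes of one MLPBlock, as implied by the printed shapes above.
block_shapes = {
    "input_dense/kernel": (hidden, hidden),
    "input_dense/bias": (hidden,),
    "output_dense/kernel": (hidden, hidden),
    "output_dense/bias": (hidden,),
    "pre_norm/scale": (hidden,),
    "pre_norm/bias": (hidden,),
}
# Global shapes in the train state: (model, layer, *feature_dims).
global_shapes = {name: (model_axis_size, layers_per_stage, *s) for name, s in block_shapes.items()}
print(global_shapes["input_dense/kernel"])  # (4, 2, 512, 512)

per_stage = layers_per_stage * sum(numel(s) for s in block_shapes.values())
print(per_stage, model_axis_size * per_stage)  # 1052672 4210688
```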

We can also check that each model device has initialized its parameters differently, by comparing the parameters on different devices.

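pprint(state_pp.params["pipeline"]["sharded"]["mlp_layers"]["block"]["input_dense"]["kernel"].value[:, :, 0, 0])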
Array([[ 0.01044598, -0.07416785],
       [-0.04605146,  0.0008348 ],
       [-0.00904123, -0.00018691],
       [ 0.00661926, -0.06117292]], dtype=float32)

The printed parameter values above are indeed different for each device, since different RNG keys were used for the initialization of the parameters on each device.

Additionally, we check the input and output layers to ensure that they are masked correctly.

print("Input Layer")
pprint(state_pp.params["input_dense"]["sharded"]["kernel"].value[:, 0, 0])
print("\nOutput layer")
pprint(state_pp.params["output_dense"]["sharded"]["kernel"].value[:, 0, 0])
Input Layer
Array([-0.0754908,  0.       ,  0.       ,  0.       ], dtype=float32)

Output layer
Array([ 0.        ,  0.        ,  0.        , -0.07138917], dtype=float32)

The input layer only has non-zero weights on the first element of the first axis, which corresponds to the first model device. For the output layer, only the last element is non-zero, corresponding to the last model device, i.e. the last pipeline stage. This is as expected and completes our check of the initialization.
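The masking pattern above can be reproduced with a small NumPy sketch. It assumes, as the printed values suggest, that every model device holds its own copy of the input and output layer, and that all copies except the one on the first stage (input) or last stage (output) are zeroed out. `mask_to_stage` is a hypothetical helper for illustration, and the shapes are small stand-ins. The notebook's own masking code is not reproduced here.

```python
import numpy as np


def mask_to_stage(per_device_params: np.ndarray, keep_stage: int) -> np.ndarray:
    """Zero out every model-axis copy except the one on `keep_stage`.

    The first dimension of `per_device_params` is assumed to be the model axis.
    """
    stage_idx = np.arange(per_device_params.shape[0])
    # Reshape the boolean stage mask so it broadcasts over the parameter dims.
    keep = (stage_idx == keep_stage).reshape((-1,) + (1,) * (per_device_params.ndim - 1))
    return np.where(keep, per_device_params, 0.0)


num_stages = 4
rng = np.random.default_rng(0)
# Small stand-in shapes; the real layers are (784, 512) and (512, 10) per device.
input_kernel = rng.normal(size=(num_stages, 6, 5)).astype(np.float32)
output_kernel = rng.normal(size=(num_stages, 5, 3)).astype(np.float32)

masked_input = mask_to_stage(input_kernel, keep_stage=0)
masked_output = mask_to_stage(output_kernel, keep_stage=num_stages - 1)
print(masked_input[:, 0, 0])   # non-zero only at index 0
print(masked_output[:, 0, 0])  # non-zero only at index 3
```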

Training with Pipeline Parallelism

With the model and optimizer initialized, we can now define the training step and train the model. The training step is very similar to the one in the previous notebook, the main difference being that only the last model device computes the loss. For more expensive output layers, one could instead scatter the pipeline outputs over the model axis, compute the loss of each sub-batch on its own device, and average the results. This would balance the computational load better, but would also incur a higher communication overhead. For simplicity, we compute the loss on the last device only, so that we can ignore the losses on the other devices.

Another small difference is that the dropout RNG now also differs across the model axis. We achieve this by folding the RNG key over both the data and the model axis, so that every device uses a different key and thus generates a different dropout mask. For other random operations, we may want to fold the key over only one of the two axes, depending on the operation and the desired behavior.
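The effect of folding over one axis versus both can be illustrated without JAX. The sketch below is a NumPy analogue, not the notebook's `fold_rng_over_axis`: it derives each device's generator from a base seed plus the device's axis indices via the `spawn_key` of `numpy.random.SeedSequence`, which here plays the role of `jax.random.fold_in`. The mesh of 2 data devices by 4 model devices, the seed, and the helper names are assumptions chosen for illustration.

```python
import numpy as np

DATA_SIZE, MODEL_SIZE = 2, 4  # assumed mesh shape for illustration
BASE_SEED = 42


def dropout_mask(data_idx: int, model_idx: int, fold_model: bool,
                 shape=(128,), rate: float = 0.5) -> np.ndarray:
    """Dropout keep-mask for one device; the seed is 'folded' with the chosen axis indices."""
    spawn_key = (data_idx, model_idx) if fold_model else (data_idx,)
    rng = np.random.default_rng(np.random.SeedSequence(BASE_SEED, spawn_key=spawn_key))
    return rng.random(shape) >= rate


def num_distinct_masks(fold_model: bool) -> int:
    """Count how many different masks the whole mesh produces."""
    masks = {
        dropout_mask(d, m, fold_model).tobytes()
        for d in range(DATA_SIZE)
        for m in range(MODEL_SIZE)
    }
    return len(masks)


# Folding over the data axis only: devices within one data replica share a mask.
print(num_distinct_masks(fold_model=False))
# Folding over both axes: each of the 8 devices gets its own mask.
print(num_distinct_masks(fold_model=True))
```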

def loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
    # Since dropout masks vary across the batch dimension, we want each device to generate a
    # different mask. We can achieve this by folding the rng over the data and model axes, so
    # that each device gets a different rng and thus mask.
    dropout_rng = fold_rng_over_axis(rng, (config.data_axis_name, config.model_axis_name))
    # Remaining computation is the same as before for single device.
    logits = apply_fn(
        {"params": params},
        batch.inputs,
        train=True,
        rngs={"dropout": dropout_rng},
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = batch.inputs.shape[0]
    # Mask out loss and accuracy for pipeline stages except last one.
    model_idx = jax.lax.axis_index(config.model_axis_name)
    model_size = jax.lax.psum(1, config.model_axis_name)
    loss = jnp.where(model_idx != model_size - 1, 0.0, loss)
    correct_pred = jnp.where(model_idx != model_size - 1, False, correct_pred)
    batch_size = jnp.where(model_idx != model_size - 1, 0, batch_size)
    # Collect metrics and return loss.
    step_metrics = {
        "loss": (loss.sum(), batch_size),
        "accuracy": (correct_pred.sum(), batch_size),
    }
    loss = loss.mean()
    return loss, step_metrics
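Masking the non-final stages is enough because the metrics are later summed with a `psum` over both axes in the training step, so only the last stage contributes to the summed loss and batch size. The following NumPy simulation of that reduction assumes 4 model devices and a batch of 3 examples with made-up per-example losses; no real model is involved.

```python
import numpy as np

MODEL_SIZE = 4
# Hypothetical per-example losses on each model device (batch of 3). Only the
# last stage holds the actual pipeline output; earlier stages hold stale values.
per_device_loss = np.array([
    [9.0, 9.0, 9.0],
    [7.0, 7.0, 7.0],
    [5.0, 5.0, 5.0],
    [0.2, 0.4, 0.6],
])
batch_size = per_device_loss.shape[1]

loss_sums, batch_sizes = [], []
for model_idx in range(MODEL_SIZE):
    is_last = model_idx == MODEL_SIZE - 1
    # Same rule as in loss_fn: zero out loss and batch size on non-final stages.
    loss = np.where(is_last, per_device_loss[model_idx], 0.0)
    loss_sums.append(loss.sum())
    batch_sizes.append(batch_size if is_last else 0)

# Summing over devices stands in for jax.lax.psum over the model axis.
mean_loss = sum(loss_sums) / sum(batch_sizes)
print(mean_loss)  # equals the mean loss of the last stage alone
```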

The training step is also very similar to before. Although gradient accumulation is still supported, it is recommended to integrate those minibatches into the pipeline as micro-batches instead. The efficiency of pipeline parallelism improves with the number of micro-batches, so we want to keep the pipeline as busy as possible.
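The benefit of more micro-batches can be quantified with the bubble fraction from the GPipe analysis [Huang et al., 2019]. With K stages and M micro-batches, each stage is busy for M of the M + K - 1 time steps of a pass, so a fraction (K - 1) / (M + K - 1) of the time is idle. The short function below evaluates this formula. It assumes equal compute time per stage and ignores communication, so it is an idealized estimate rather than a measurement of this notebook's implementation.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a GPipe-style schedule with equal per-stage cost."""
    if num_stages < 1 or num_microbatches < 1:
        raise ValueError("num_stages and num_microbatches must be positive")
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for m in (1, 4, 16, 64):
    print(f"K=4, M={m:>2}: bubble = {bubble_fraction(4, m):.3f}")
# K=4, M= 1: bubble = 0.750
# K=4, M= 4: bubble = 0.429
# K=4, M=16: bubble = 0.158
# K=4, M=64: bubble = 0.045
```

With 4 stages, going from 1 to 16 micro-batches cuts the idle time from three quarters of the pass to under a sixth.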

def train_step_pp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across the data and model axes before updating.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name, config.model_axis_name))
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas (both model and data axes).
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_map(
            lambda x: jax.lax.psum(x, axis_name=(config.data_axis_name, config.model_axis_name)),
            step_metrics,
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

Finally, we compile the training step with jax.jit. As before, we use the jax.eval_shape function to determine the shapes of the metrics we want to keep track of, and initialize those metrics with zeros before running a first step.

train_step_pp_fn = jax.jit(
    shard_map(
        train_step_pp,
        mesh,
        in_specs=(state_pp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_pp_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
_, metric_shapes = jax.eval_shape(train_step_pp_fn, state_pp, None, batch)
metrics_pp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_pp, metrics_pp = train_step_pp_fn(state_pp, metrics_pp, batch)
/home/plippe/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/interpreters/ UserWarning: Some donated buffers were not usable: ShapedArray(float32[1,512]), ShapedArray(float32[1,784,512]), ShapedArray(float32[1,10]), ShapedArray(float32[1,512,10]), ShapedArray(float32[1,512]), ShapedArray(float32[1,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,512]), ShapedArray(float32[1,784,512]), ShapedArray(float32[1,10]), ShapedArray(float32[1,512,10]), ShapedArray(float32[1,512]), ShapedArray(float32[1,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,512]), ShapedArray(float32[1,784,512]), ShapedArray(float32[1,10]), ShapedArray(float32[1,512,10]), ShapedArray(float32[1,512]), ShapedArray(float32[1,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512,512]), ShapedArray(float32[1,2,512]), ShapedArray(float32[1,2,512]).
See an explanation at
  warnings.warn("Some donated buffers were not usable:"

For reference, we print the number of parameters of the model. Since we are running on CPU, we keep the model extra small.

print(f"Number of parameters: {get_num_params(state_pp):_}")
Number of parameters: 5_842_984
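The printed count can be reproduced from the per-device parameter shapes that appear in the donated-buffer warning above. The quick check below assumes that `get_num_params` sums the sizes of all stored parameter arrays over the full model axis. Under that assumption, each of the 4 model devices holds 1,460,746 parameters, including its own copy of the input and output layers (of which only one copy is non-zero after masking), and 4 × 1,460,746 = 5,842,984.

```python
from math import prod

MODEL_SIZE = 4  # number of pipeline stages / model devices

# Per-device parameter shapes from the donated-buffer warning above,
# with the leading model axis of size 1 dropped.
non_pipeline_shapes = [
    (784, 512), (512,),   # input layer kernel and bias
    (512, 10), (10,),     # output layer kernel and bias
    (512,), (512,),       # two further 512-dim vectors
]
pipeline_shapes = [
    (2, 512, 512), (2, 512),  # input_dense kernel and bias (2 layers per stage)
    (2, 512, 512), (2, 512),  # output_dense kernel and bias
    (2, 512), (2, 512),       # pre_norm scale and bias
]

per_device = sum(prod(s) for s in non_pipeline_shapes + pipeline_shapes)
total = MODEL_SIZE * per_device
print(f"Per device: {per_device:_}")  # Per device: 1_460_746
print(f"Total:      {total:_}")       # Total:      5_842_984
```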

Let’s check if our pipeline training step is working as expected by running it for a few steps on the randomized classification task.

for _ in range(15):
    state_pp, metrics_pp = train_step_pp_fn(state_pp, metrics_pp, batch)
final_metrics_pp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_pp, final_metrics_pp = train_step_pp_fn(state_pp, final_metrics_pp, batch)
print_metrics(final_metrics_pp, title="Final Metrics - Pipeline")
 Final Metrics - Pipeline
accuracy: 1.000000
loss: 0.000185

As we can see, the model trains as expected and reaches a low loss after very few steps. This is mainly due to the simplicity of the task rather than the pipeline parallelism, but it shows that the pipelined training step runs end to end across all devices. We will perform a closer test at the end of the notebook to verify that the model parallelized across devices behaves identically to the non-parallelized single-device model.

Intermediate Summary

In this notebook, we have discussed and implemented pipeline parallelism with micro-batching. We have shown how to split the model into stages and implement pipeline parallelism in JAX, how to combine the pipeline model with input and output layers, and how to initialize and train the model. We have also discussed the trade-offs of the micro-batching strategy and how to choose the micro-batch size. In the next part, we will implement another method to mitigate the pipeline bubble problem, namely looping pipelines.

References and Resources

[Huang et al., 2019] Huang, Y., Cheng, Y., Bapna, A., Firat, O., Chen, D., Chen, M., Lee, H., Ngiam, J., Le, Q.V. and Wu, Y., 2019. GPipe: Efficient training of giant neural networks using pipeline parallelism. Advances in Neural Information Processing Systems, 32.

[Narayanan et al., 2021] Narayanan, D., Shoeybi, M., Casper, J., LeGresley, P., Patwary, M., Korthikanti, V., Vainbrand, D., Kashinkunti, P., Bernauer, J., Catanzaro, B. and Phanishayee, A., 2021, November. Efficient large-scale language model training on GPU clusters using Megatron-LM. In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis (pp. 1-15).

[Lamy-Poirier, 2023] Lamy-Poirier, J., 2023. Breadth-First Pipeline Parallelism. Proceedings of Machine Learning and Systems, 5.

[McKinney, 2023] McKinney, A., 2023. A Brief Overview of Parallelism Strategies in Deep Learning. Blog post.

[Huggingface, 2024] Huggingface, 2024. Model Parallelism. Documentation.

[DeepSpeed, 2024] DeepSpeed, 2024. Pipeline Parallelism. Documentation.
