Part 2.2: (Fully-Sharded) Data Parallelism

Filled notebook: View filled on GitHub · Open filled in Colab

Author: Phillip Lippe

In the following series of tutorials, we will explore different parallelism strategies for training large deep learning models. Our focus will be on three common parallelism strategies: data parallelism, pipeline parallelism, and tensor parallelism. Data parallelism, as the name suggests, focuses on parallelizing the processing of the data. If we are given a very large batch, we divide it into smaller batches and distribute them across multiple devices. Each device processes a different batch of data in parallel. Afterwards, we aggregate the results from each device to update the model. Data parallelism is the most common parallelism strategy used in deep learning and is well supported in most frameworks. Thus, we start with data parallelism in this tutorial. In later tutorials, we will explore pipeline and tensor parallelism, which focus on parallelizing the computation of the model itself. A short overview of the three parallelism strategies is shown in the figure below.

(Figure: overview of data parallelism, pipeline parallelism, and tensor parallelism.)

We will focus on implementing data parallelism in JAX, but a lot of the concepts can be easily transferred to other frameworks like PyTorch or TensorFlow. With distributed computing introduced in Part 2.1, we can now implement a simple data parallelism strategy to train a small neural network on multiple devices. We then discuss fully-sharded data parallelism (FSDP), which distributes the model parameters across multiple devices and reduces memory consumption (also known as part of the ZeRO optimizer).

Prerequisites

First, let’s start by setting up the basic environment and utility functions we have seen in previous notebooks. We download the Python scripts of the previous notebooks below. This is only needed when running on Google Colab; local execution will skip this step automatically.

[1]:
import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download.
python_files = ["single_gpu.py", "utils.py"]
# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

As in the previous part, we set up the notebook such that it can be run on a CPU with 8 simulated devices and no expensive hardware is required. If you are running on Google Colab, you do not need to select a GPU runtime, as we will simulate multiple devices on a single CPU. If you are running on your local machine and have multiple GPUs available, you can comment out the line below and use set_XLA_flags_gpu instead to set the XLA flags we have seen in the previous parts.

[2]:
from utils import simulate_CPU_devices

simulate_CPU_devices()
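If you are curious what such a helper does under the hood, the sketch below shows a rough equivalent (the actual implementation in utils.py may differ): it asks XLA to expose the host CPU as several virtual devices via an environment flag, which has to be set before JAX initializes its backends.

import os

def simulate_CPU_devices_sketch(device_count: int = 8):
    # Ask XLA to split the host CPU into multiple virtual devices.
    # Must be set before JAX creates its backends (i.e. before the first JAX call).
    flags = os.environ.get("XLA_FLAGS", "")
    os.environ["XLA_FLAGS"] = flags + f" --xla_force_host_platform_device_count={device_count}"
    # Hide any GPUs so that the CPU backend is used.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""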

With the environment variables set, we can import our required libraries and start with the implementation.

[3]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

We again use some small utilities from the previous notebook (Part 1.1) to reduce the code duplication in this notebook.

[4]:
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics

Additionally, in Part 2.1, we have already implemented fold_rng_over_axis, which allows us to create mesh-axis-specific RNGs. We will use this utility in our data parallelism implementation in this notebook.

[5]:
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Folds the random number generator over the given axis.

    This is useful for generating a different random number for each device
    across a certain axis (e.g. the model axis).

    Args:
        rng: The random number generator.
        axis_name: The axis name to fold the random number generator over.

    Returns:
        A new random number generator, different for each device index along the axis.
    """
    axis_index = jax.lax.axis_index(axis_name)
    return jax.random.fold_in(rng, axis_index)

Data Parallelism

In data parallelism, we aim to use our multiple devices to increase our batch size. Each device holds the same model and parameters, and processes a different batch of data in parallel. After processing the data and obtaining the gradients for each batch, we aggregate the gradients over the devices and update our model. The main advantage of data parallelism is that it is easy to implement and scales well with the number of devices, since the devices only need to communicate once per batch. The main disadvantage is that the model size is limited by the memory of a single device, which can become a bottleneck for very large models; we will discuss how to overcome this in the next section.
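To make this idea concrete before we build the full training loop, here is a minimal, self-contained sketch of a data-parallel update step (the name dp_sgd_step and the toy linear model are purely illustrative; the actual implementation with shard_map follows below). Each device computes gradients on its local shard of the batch, and a single jax.lax.pmean per step keeps the replicated parameters in sync.

def dp_sgd_step(params: jax.Array, inputs: jax.Array, targets: jax.Array, lr: float = 1e-3):
    # To be called inside shard_map/pmap with a mesh axis named "data".
    def local_loss(p):
        preds = inputs @ p  # toy linear model on this device's batch shard
        return jnp.mean((preds - targets) ** 2)

    grads = jax.grad(local_loss)(params)            # per-device gradients
    grads = jax.lax.pmean(grads, axis_name="data")  # one all-reduce per step
    return params - lr * grads                      # identical update on every device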

For now, let’s start with plain data parallelism. By using shard_map, we can focus on writing single-device code, and shard_map will take care of the parallelization. Hence, we can simply write our model and training loop as if we would run them on a single device. We reuse the small example classifier from the previous tutorial below:

[6]:
class DPClassifier(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.config.num_classes,
            dtype=self.config.dtype,
            name="output_dense",
        )(x)
        x = x.astype(jnp.float32)
        return x

We again specify all hyperparameters in a config. Feel free to adjust the hyperparameters to your needs. The only addition compared to the previous tutorial is the data_axis_name attribute, which denotes the mesh axis over which we want to perform data parallelism. We will use this attribute in the following sections to coordinate communication over the data axis.

[7]:
data_config = ConfigDict(
    dict(
        batch_size=128,
        num_classes=10,
        input_size=784,
    )
)
model_config = ConfigDict(
    dict(
        hidden_size=512,
        dropout_rate=0.1,
        dtype=jnp.bfloat16,
        num_classes=data_config.num_classes,
        data_axis_name="data",
    )
)
optimizer_config = ConfigDict(
    dict(
        learning_rate=1e-3,
        num_minibatches=4,
    )
)
config = ConfigDict(
    dict(
        model=model_config,
        optimizer=optimizer_config,
        data=data_config,
        data_axis_name=model_config.data_axis_name,
        seed=42,
    )
)

Initialization

We start by initializing the model and optimizer. Also here, we can continue to write the initialization as if we would run it on a single device. We create the objects below:

[8]:
model_dp = DPClassifier(config=config.model)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)

We also create an example batch of data to test the model. Since shard map will take care of the parallelization and sharding of the inputs, we can simply create the batch as if we would run it on a single device:

[9]:
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
batch = Batch(
    inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)),
    labels=jax.random.randint(
        data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
    ),
)
2024-03-07 10:46:28.299510: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:273] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Each device will process a batch size of config.data.batch_size // num_devices. In a full training run, you may want to prefetch the array to the devices in the correct sharding to avoid the initial transfer time. This can be done via flax.jax_utils.prefetch_to_device, which supports async placement of arrays on devices, and is especially useful for large arrays in a dataset. However, for the purpose of this tutorial, we will keep the array on the first device and let shard map take care of the transfer.
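As a hedged alternative sketch (not used in the rest of this notebook), we could also place the batch explicitly with a NamedSharding over the data axis before calling the jitted function. The name mesh_example below is defined only for this snippet; the mesh we actually use is created in the next subsection.

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh_example = Mesh(np.array(jax.devices()), ("data",))
data_sharding = NamedSharding(mesh_example, P("data"))  # shard the batch dimension over "data"
sharded_inputs = jax.device_put(batch.inputs, data_sharding)
print(sharded_inputs.sharding)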

We can now write an initialization function. Again, this is the same function as we have seen for single-device training:

[10]:
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    params = variables.pop("params")
    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
    return state

Before we can execute the initialization, we need to define the mesh and wrap the initialization function with shard_map. We define the mesh below:

[11]:
device_array = np.array(jax.devices())
mesh = Mesh(device_array, (config.data_axis_name,))

We create a single-dimensional mesh, with all devices along the data axis. Hence, we will perform data parallelism over all devices.

For wrapping the initialization function with shard_map, we need to specify the mesh and the sharding specifications for the input and output. For generating the parameters, we input an RNG which we want to replicate across all devices. This ensures that all devices have the same parameters. The input batch on the other hand will be sharded over the data axis. As an output, we expect a train state which is identical on all devices. We wrap the initialization function with shard_map below:

[12]:
init_dp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_dp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=P(),
        check_rep=False,
    ),
)

The jitting is optional, but may be necessary in models where we make use of jax.lax.axis_index or other operations that are only supported within a jitted shard_map function. We can now execute the initialization function and check the resulting train state:

[13]:
state_dp = init_dp_fn(model_init_rng, batch.inputs)
print("DP Parameters")
pprint(jax.tree_map(lambda x: (x.shape, x.sharding), state_dp.params))
DP Parameters
{'input_dense': {'bias': ((512,), GSPMDSharding({replicated})),
                 'kernel': ((784, 512), GSPMDSharding({replicated}))},
 'output_dense': {'bias': ((10,), GSPMDSharding({replicated})),
                  'kernel': ((512, 10), GSPMDSharding({replicated}))}}

We find all parameters have the expected shape and are replicated across all devices. We can now move on to the training loop.

Train Step

We can write the train step almost as if we would run it on a single device. The only difference is the dropout RNG. As mentioned before, we want each device to use a different dropout mask. However, the RNG in the train state is replicated across all devices. We can use our function fold_rng_over_axis to split the RNG key across devices on the data axis. This device-specific key can then be passed to the dropout layers. We implement the train step below:

[14]:
def loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
    # Since dropout masks vary across the batch dimension, we want each device to generate a
    # different mask. We can achieve this by folding the rng over the data axis, so that each
    # device gets a different rng and thus mask.
    dropout_rng = fold_rng_over_axis(rng, config.data_axis_name)
    # Remaining computation is the same as before for single device.
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": dropout_rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = batch.inputs.shape[0]
    step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
    loss = loss.mean()
    return loss, step_metrics

In the previous tutorial, we explored the option of accumulating the gradients of multiple minibatches before updating the model, to reduce the memory footprint. While data parallelism usually already allows for a much larger batch by parallelizing across multiple devices, we can still use this option to increase the batch size further. Since the forward and backward pass do not require any communication between devices, we can use the same accumulate_gradients function as for single-device training and scan over smaller splits of the per-device batch. We then accumulate the gradients over subbatches on each device, and only communicate the final aggregated gradients to update the model.
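For reference, a simplified sketch of what such an accumulate_gradients helper does is shown below (the actual implementation imported from single_gpu.py may differ, e.g. by using jax.lax.scan); it is only meant to illustrate that no cross-device communication is required during the accumulation itself.

def accumulate_gradients_sketch(state, batch, rng, num_minibatches, loss_fn):
    # Split the per-device batch into smaller minibatches and sum their gradients.
    minibatch_size = batch.inputs.shape[0] // num_minibatches
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    grads, metrics = None, None
    for i in range(num_minibatches):
        sl = slice(i * minibatch_size, (i + 1) * minibatch_size)
        minibatch = Batch(inputs=batch.inputs[sl], labels=batch.labels[sl])
        step_rng = jax.random.fold_in(rng, i)
        (_, step_metrics), step_grads = grad_fn(state.params, state.apply_fn, minibatch, step_rng)
        grads = step_grads if grads is None else jax.tree_map(jnp.add, grads, step_grads)
        metrics = step_metrics if metrics is None else jax.tree_map(jnp.add, metrics, step_metrics)
    # Average the gradients over the minibatches (each minibatch loss is already a mean).
    grads = jax.tree_map(lambda g: g / num_minibatches, grads)
    return grads, metrics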

In the train step, we need to communicate the gradients over the data axis. After obtaining the gradients and metrics per device, we want to average the gradients over all devices and update the model. We can use jax.lax.pmean to average the gradients over the data axis, and then apply the optimizer step to update the model on each device in parallel. Similarly, we use jax.lax.psum to sum the statistics in the metrics over the data axis. Alternatively, we could sync the metrics only before we want to log them, to reduce the communication overhead. However, since the metrics are usually just a few scalars compared to millions or billions of parameters, the communication overhead is usually negligible. The returned state and metrics will be the same on all devices, and we can use them to continue the training loop.

[15]:
def train_step_dp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = jax.tree_map(lambda g: jax.lax.pmean(g, axis_name=config.data_axis_name), grads)
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

We can now wrap the train step with shard_map and jit it. We need to specify the mesh and the sharding specifications for the input and output. The input batch will be sharded over the data axis. All other arrays, i.e. the input/output train state and metrics, will be replicated across all devices. Further, we can specify the state and metrics as donatable to avoid unnecessary memory overhead. We wrap the train step with shard_map below:

[16]:
train_step_dp_fn = jax.jit(
    shard_map(
        train_step_dp,
        mesh,
        in_specs=(P(), P(), P(config.data_axis_name)),
        out_specs=(P(), P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)

Before running the training step, we want to find the shape and PyTree structure of the metrics. We can use jax.eval_shape or flax.jax_utils.partial_eval_by_shape to find the shape of the metrics without running the full computation. Otherwise, we would need to compile the training step twice: once for the input metrics being None, and once for the input metrics having the correct shape. We avoid this overhead with the shape evaluation below:

[17]:
_, metric_shapes = jax.eval_shape(
    train_step_dp_fn,
    state_dp,
    None,
    batch,
)
metrics_dp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)

Finally, we can run the training loop for a few steps and check the resulting metrics:

[18]:
for _ in range(15):
    state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
final_metrics_dp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_dp, final_metrics_dp = train_step_dp_fn(state_dp, final_metrics_dp, batch)
print_metrics(final_metrics_dp)
accuracy: 1.000000
loss: 0.003343

As we can see, the model is training as expected and is able to overfit on the single batch of data. Let’s once more check how the parameters are distributed across the devices:

[19]:
print("DP Parameters")
pprint(jax.tree_map(lambda x: (x.shape, x.sharding), state_dp.params))
print("Metrics")
pprint(jax.tree_map(lambda x: (x.shape, x.sharding), final_metrics_dp))
DP Parameters
{'input_dense': {'bias': ((512,), GSPMDSharding({replicated})),
                 'kernel': ((784, 512), GSPMDSharding({replicated}))},
 'output_dense': {'bias': ((10,), GSPMDSharding({replicated})),
                  'kernel': ((512, 10), GSPMDSharding({replicated}))}}
Metrics
{'accuracy': (((), GSPMDSharding({replicated})),
              ((), GSPMDSharding({replicated}))),
 'loss': (((), GSPMDSharding({replicated})), ((), GSPMDSharding({replicated})))}

Both parameters and metrics are still replicated over all devices, suggesting our implementation works as expected. We can now move on to the next section.

Parameter Sharding

So far, we have implemented data parallelism where the model parameters are replicated across all devices. This is a simple and effective way to parallelize the training of a model. However, as model sizes continue to grow, the memory consumption of the model parameters can become a bottleneck. For instance, consider a model with 1 billion parameters and 8 devices: each device would need to hold the full 1 billion parameters. With each parameter being in float32 and an optimizer like Adam keeping two additional float32 values per parameter (first- and second-order momentum), the model alone already takes up 12GB of device memory, i.e. half the 24GB memory of an A5000 GPU. To overcome this, we can shard the model parameters across multiple devices. This is known as fully-sharded data parallelism (FSDP) and is part of the ZeRO optimizer. In FSDP, each device holds a different part of the model parameters and their optimizer state. Before executing a layer or module, we gather the respective parameters from all devices such that we can process the data in parallel. In the backward pass, we scatter the gradients back to the respective devices and update the model. This way, we can significantly reduce the memory consumption of the model parameters and scale to very large models.

The ZeRO optimizer generalizes this concept by allowing for different stages of sharding: sharding only the optimizer state (\(P_{os}\)), additionally sharding the gradients (\(P_{os+g}\)), or also sharding the model parameters (\(P_{os+g+p}\), FSDP). The figure below shows the different stages of ZeRO (figure credit: Rajbhandari et al., 2020). Here, \(\Psi\) denotes the number of parameters in the model, and \(K\) the number of bytes of optimizer state per parameter. In the mixed precision training setup shown, we use bfloat16 for the model parameters (2 bytes) and the gradients (2 bytes), and float32 for the Adam optimizer state (8 bytes) plus a float32 copy of the parameters (4 bytes). Thus, our memory footprint is 16 bytes per parameter. By using the different stages of ZeRO, we can reduce the per-device memory consumption of parameters, gradients, and optimizer state by up to a factor of \(N_d\), where \(N_d\) is the number of devices.

(Figure: per-device memory consumption of the different ZeRO stages.)
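To make the memory arithmetic above concrete, the short calculation below reproduces the rough per-device numbers for a 1-billion-parameter model on 8 devices in the mixed-precision setup (2 bytes bf16 parameters, 2 bytes bf16 gradients, 12 bytes float32 Adam state including the parameter copy). These are back-of-the-envelope estimates and ignore activations and communication buffers.

num_params = 1e9
num_devices = 8
bytes_params, bytes_grads, bytes_opt = 2, 2, 12  # 16 bytes per parameter in total

dp_baseline = num_params * (bytes_params + bytes_grads + bytes_opt)
p_os = num_params * (bytes_params + bytes_grads + bytes_opt / num_devices)
p_os_g = num_params * (bytes_params + (bytes_grads + bytes_opt) / num_devices)
p_os_g_p = num_params * (bytes_params + bytes_grads + bytes_opt) / num_devices
for name, mem in [("DP baseline", dp_baseline), ("P_os", p_os), ("P_os+g", p_os_g), ("P_os+g+p (FSDP)", p_os_g_p)]:
    print(f"{name}: {mem / 1e9:.2f} GB per device")
# DP baseline: 16.00 GB, P_os: 5.50 GB, P_os+g: 3.75 GB, P_os+g+p (FSDP): 2.00 GB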

We will focus on the basic concept of FSDP in this tutorial, but the other ZeRO stages can also be implemented in similar ways. We will start by sharding the model parameters across multiple devices and then implement the forward and backward pass accordingly. We will also discuss the communication of the optimizer state and gradients in the backward pass.

Sharding setup

We start by sharding the model parameters across multiple devices. The strategy we follow is:

- During initialization, we create the full parameters on each device.
- We then use jax.lax.axis_index to split the parameters across devices and only keep a shard of the parameters on each device.
- We annotate the parameters with the sharding specification, so we know how to put the parameters back together in the forward pass.

For annotating the sharding specification of parameters, Flax provides a wrapper called nn.Partitioned (docs). This wrapper takes as input a parameter and a list of axis names, similar to the PartitionSpec we have seen before. We can then use the parameter as before, but can use the axis name specification to determine the communication needed between devices, the shard map specification, and more. Further, whenever we apply transformations on modules with partitioned parameters, e.g. vmap or scan, the partitioned parameters will be transformed accordingly and the annotated axis adjusted to the new shapes. From now on, we consider parameters to be either a jax.Array if fully replicated, or a flax.linen.Partitioned if sharded across devices. We create a small type annotation for this case below.

[20]:
Parameter = jax.Array | nn.Partitioned
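As a quick illustration of the wrapper (a toy example, not part of the model), nn.Partitioned simply annotates which mesh axis each array axis is sharded over, and nn.get_partition_spec turns a PyTree of such annotations into PartitionSpec objects that we can later pass to shard_map:

example_params = {
    "kernel": nn.Partitioned(value=jnp.zeros((8, 16)), names=("data", None)),  # sharded on axis 0
    "bias": jnp.zeros((16,)),  # plain array, i.e. fully replicated
}
print(nn.get_partition_spec(example_params))
# -> {'kernel': PartitionSpec('data', None), 'bias': PartitionSpec()}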

We now need to write a function that shards the parameters across devices for FSDP. It takes as input a PyTree of parameters and a mesh axis name, and determines for each parameter the sharding specification. To keep the global shape of a parameter untouched, we look for an axis that can be split equally across the number of devices. If multiple such axes exist, we select the largest one, to keep the sharding consistent when varying the number of devices. We then use this axis as the sharding axis, and wrap the parameter in a nn.Partitioned with the axis name at the corresponding position. For very small parameters, sharding may not be beneficial, since the communication time may outweigh the memory savings. For instance, we are usually interested in sharding large weight matrices, while the bias and the scaling parameters of a normalization layer have negligible memory costs. We can therefore specify a minimum size for sharding, and only shard parameters that are larger than this size. We implement the function below:

[21]:
@jax.named_scope("shard_params")
def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree:
    """Shard parameters across the given mesh axis.

    Args:
        params: The parameters to shard.
        axis_name: The axis to shard parameters across.
        min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.

    Returns:
        PyTree of same structure as params, but with leaves sharded over new axis if possible.
    """
    axis_idx = jax.lax.axis_index(axis_name)
    axis_size = jax.lax.psum(1, axis_name)

    def _split(x: Parameter) -> Parameter:
        if isinstance(x, nn.Partitioned):
            value, names = x.value, x.names
        else:
            value = x
            names = (None,) * value.ndim
        if axis_name in names:
            logging.warning(
                f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}."
            )
            return x
        elif value.size <= min_weight_size:
            logging.info(
                f"Parameter {value.shape} with names {names} too small to shard, size {value.size} < {min_weight_size}."
            )
            return x
        else:
            shape = value.shape
            idx = np.argsort(shape)[::-1]  # Shard along largest possible axis.
            for i in idx:
                if shape[i] % axis_size == 0 and names[i] is None:
                    split_size = shape[i] // axis_size
                    p_sharded = nn.Partitioned(
                        value=lax.dynamic_slice_in_dim(  # Shard to keep on present device.
                            value, axis_idx * split_size, split_size, axis=i
                        ),
                        names=names[:i] + (axis_name,) + names[i + 1 :],
                    )
                    return p_sharded
            logging.warning(
                f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found."
            )
            return x

    return jax.tree_util.tree_map(
        _split,
        params,
        is_leaf=lambda x: isinstance(
            x, nn.Partitioned
        ),  # Consider a nn.Partitioned object as a leaf.
    )

In the split function, we first check for each parameter whether it has already been sharded over the given axis. This case should usually not occur, since we cannot shard a parameter twice over the same mesh axis (otherwise information would be lost). However, a parameter may already be a nn.Partitioned object sharded over a different mesh axis, for instance when we combine data parallelism with another parallelism strategy. Hence, we allow for nn.Partitioned inputs with other axis names, and look for a new, not-yet-sharded axis to shard over. We then shard the parameter over this axis and return the new nn.Partitioned object. Note that the logging statements are only evaluated during the first run of the function when jitted, and can give us a hint when something went wrong.

With the parameters sharded, we now need to write a function that gathers the full parameters again on each device. This is necessary for the forward pass, where every device computes the module output with the full parameters in the usual single-device style. We can use jax.lax.all_gather to gather the parameter shards from all devices and concatenate them along the sharding axis. In the backward pass, we use jax.lax.psum_scatter to scatter the gradients back to the respective devices. One subtle difference to a non-sharded parameter is that in our previous data-parallel training step, we averaged the gradients over devices, while jax.lax.psum_scatter would instead sum them. To keep the behavior consistent, we adjust the backward pass to average the gradients over the devices. We implement the adjusted gather operation below:

[22]:
def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str):
    """Gathering with averaging gradients across replicas."""
    axis_size = jax.lax.psum(1, axis_name)

    # Define a custom gradient for the gather operation.
    @jax.custom_gradient
    def f(x):
        def grad_fn(g):
            # pmean_scatter
            return (
                jax.lax.psum_scatter(g, axis_name, scatter_dimension=axis, tiled=True) / axis_size
            )

        return jax.lax.all_gather(x, axis_name, axis=axis, tiled=True), grad_fn

    return f(x)

We can now write a function that takes as input a PyTree of sharded parameters and gathers the full parameters on each device. This is the reverse operation of the shard_params function. We implement it below:

[23]:
@jax.named_scope("gather_params")
def gather_params(params: PyTree, axis_name: str) -> PyTree:
    """Gather parameters from all replicas across the given axis.

    Args:
        params: The parameters to gather.
        axis_name: The axis to gather parameters across.

    Returns:
        PyTree of same structure as params, but with leaves gathered if they were a nn.Partitioned object.
    """

    def _gather(p: Parameter) -> Parameter:
        if isinstance(p, nn.Partitioned) and axis_name in p.names:
            param_shard = p.names
            shard_axis = param_shard.index(axis_name)
            value = gather_array_with_mean_grads(p.value, axis=shard_axis, axis_name=axis_name)
            # If there are any other axes that are sharded, we need to keep the partitioned structure.
            # Otherwise, we can return the value directly.
            param_shard = param_shard[:shard_axis] + (None,) + param_shard[shard_axis + 1 :]
            if any([name is not None for name in param_shard]):
                return nn.Partitioned(value, param_shard)
            else:
                return value
        else:
            return p

    return jax.tree_util.tree_map(_gather, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))

Whenever we call a module, we want to gather the parameters back to a single device, and then shard them again once the module is done computing (i.e. free the memory of the full parameters). For this, we wrap a module into a nn.map_variables transformation, which allows for transformations on the parameter before and after the module is called. We can use this to gather the parameters before the module is called, and shard them again after the module is done. We implement the function below:

[24]:
def shard_module_params(
    target: nn.Module | Callable, axis_name: str, min_weight_size: int = 2**18
) -> nn.Module | Callable:
    """Shard parameters of a module across replicas.

    Args:
        target: The module to shard.
        axis_name: The axis name to shard parameters across.
        min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.

    Returns:
        The module with sharded parameters.
    """
    return nn.map_variables(
        target,
        trans_in_fn=functools.partial(gather_params, axis_name=axis_name),
        trans_out_fn=functools.partial(
            shard_params, axis_name=axis_name, min_weight_size=min_weight_size
        ),
        mapped_collections="params",
        mutable=True,
    )

Another design aspect of FSDP is how we deal with the gathered parameters during the training step. If we use the function as is, the gathered full parameters are kept as intermediate arrays on every device until the backward pass is completed. We may have situations where the full parameters do not fit on a single device. In that case, we can remat the gather operation, such that after the forward pass of a module, we free the memory of its full parameters and only keep the sharded parameters. The parameters are then gathered again before the backward pass, and the gradients scattered back to the respective devices. This can be done by applying the remat transformation directly to gather_params, or by rematting the whole forward pass with a policy that rematerializes all all_gather operations: @partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather') (see the official documentation for more details). The final option is to remat the whole module, which we have seen in the previous tutorial. In that case, it matters whether we shard the parameters of every individual module or of the global module, since in the individual case, we also remat the gather, while in the global case, we only remat the forward pass. All these options trade off memory and communication cost, and should be chosen based on the hardware and model at hand. We provide these options in our code, such that the corresponding setup can be chosen depending on the model size.
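As a hedged sketch of the policy-based variant (the module and variable names here are illustrative), we can wrap a sharded module in nn.remat with a checkpoint policy that saves all intermediates except the outputs of all_gather, so the gathered full parameters are freed after the forward pass and re-gathered during the backward pass:

remat_policy = lambda op, *_, **__: str(op) != "all_gather"  # do not save gathered params
ShardedDense = shard_module_params(nn.Dense, axis_name=config.data_axis_name)
RematShardedDense = nn.remat(ShardedDense, policy=remat_policy)
# RematShardedDense can then be used like nn.Dense inside a model definition.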

Let’s now use the sharding function to shard our simple classifier below:

[25]:
class FSDPClassifier(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        sharded_dense = shard_module_params(
            nn.Dense,
            axis_name=self.config.data_axis_name,
            min_weight_size=self.config.min_weight_size,
        )
        x = sharded_dense(
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
        x = sharded_dense(
            features=self.config.num_classes,
            dtype=self.config.dtype,
            name="output_dense",
        )(x)
        x = x.astype(jnp.float32)
        return x

Note that we only shard the dense layers here, since all other layers (activation function, dropout) do not have any parameters.

Initialization

We can now initialize the model. First, we set the min_weight_size at which we want to shard the parameters. The original value of \(2^{18}\) is chosen such that only parameters larger than 1MB (assuming 4 bytes per parameter) are sharded, inspired by the Big Vision repository. However, for demonstration purposes, we set it much lower here so that some weights of our small example model are sharded.

[26]:
config.model.min_weight_size = 2**4

The model is then created as usual. If we want to shard the parameters of the whole model, we would use model_fsdp = shard_module_params(FSDPClassifier, axis_name=config.data_axis_name, min_weight_size=config.model.min_weight_size)(config.model) instead.

[27]:
model_fsdp = FSDPClassifier(config=config.model)

While we can reuse our initialization function, we need to adjust the shard_map around it. This is because some parameters are now sharded across devices, so the output specification changes. The partitioning is determined within the model initialization, so we cannot manually specify it beforehand as we did before. Instead, we first wrap the initialization function with shard_map and a placeholder output specification (i.e. simply set to all-replicated P() with check_rep set to False), and evaluate the output shapes, which are independent of the output specification. Since the shapes are also determined for the partitioned parameters, we can then use them to derive the sharding specification. An easy way of doing this is nn.get_partition_spec, which returns the sharding specification of a PyTree with nn.Partitioned leaves. We can then use this specification to wrap the initialization function with shard_map and the correct output specification. We implement this below:

[28]:
init_fsdp_fn = shard_map(
    functools.partial(init_dp, model=model_fsdp),
    mesh,
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,
)
state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs)
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)

We can check the specification before plugging it into the shard_map function. Since the output of the initialization function is a train state, state_fsdp_specs is a train state as well, but with each element being a partition spec instead. We can print out the specs below.

[29]:
print("RNG", state_fsdp_specs.rng)
print("\nParameters")
pprint(state_fsdp_specs.params)
print("\nOptimizer state")
pprint(state_fsdp_specs.opt_state[0])
RNG PartitionSpec()

Parameters
{'input_dense': {'bias': PartitionSpec('data',),
                 'kernel': PartitionSpec('data', None)},
 'output_dense': {'bias': PartitionSpec(),
                  'kernel': PartitionSpec('data', None)}}

Optimizer state
ScaleByAdamState(count=PartitionSpec(), mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}, nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}})

The random number generator is replicated across devices as before. The parameters now have different shardings: the kernels of both dense layers are sharded over the data axis, while the output bias is replicated across all devices due to its small size (only 10 parameters). Further, the optimizer state follows the same sharding as the parameters. The mu (first-order momentum) and nu (second-order momentum) are sharded over the data axis like the parameters, while count (the step counter) is replicated across all devices. We can now wrap the initialization function with the correct output specification and execute it.

[30]:
init_fsdp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_fsdp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_fsdp_specs,
        check_rep=False,
    )
)
state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs)

We print the shapes of the parameters to verify that the sharding was successful and their global shape is as we expected.

[31]:
print("FSDP Parameters")
pprint(jax.tree_map(lambda x: x.shape, jax.device_get(state_fsdp.params)))
FSDP Parameters
{'input_dense': {'bias': Partitioned(value=(512,), names=('data',), mesh=None),
                 'kernel': Partitioned(value=(784, 512),
                                       names=('data', None),
                                       mesh=None)},
 'output_dense': {'bias': (10,),
                  'kernel': Partitioned(value=(512, 10),
                                        names=('data', None),
                                        mesh=None)}}

All parameters have the expected global shape, and are sharded on their largest axis if they are larger than the minimum weight size. The mesh attribute in nn.Partitioned can be set to a different mesh if we do not want to shard over the global mesh, as set by the shard_map. In most cases, however, we want to shard over the global mesh, and can leave the mesh attribute as None. We can now move on to the training loop.

Train Step

In the training step, we need to adjust the synchronization of the gradients to take the parameter sharding into account. For a given parameter gradient in our PyTree, we want to average the gradient over a mesh axis if the parameter is not partitioned over that axis, and leave it unchanged otherwise. We implement this strategy in the function below. Note that, in later tutorials, we will be dealing with multiple mesh axes. Hence, the function is written such that it can handle multiple mesh axes, and we simply pass the mesh axis names over which we want to average the gradients.

[32]:
def sync_gradients(
    grads: PyTree,
    axis_names: Sequence[str],
) -> PyTree:
    """Synchronize gradients across devices.

    Gradients for parameters that are replicated over a given axis are averaged across devices.
    Parameters that are partitioned over a given axis are considered to already have a mean of
    the gradients on each device, and hence do not need to be altered.

    Args:
        grads: The gradients to synchronize.
        axis_names: The axis names to synchronize gradients across.

    Returns:
        The gradients averaged over the specified axes if they are replicated.
    """

    def sync_grad(g: Parameter) -> Parameter:
        if isinstance(g, nn.Partitioned):
            # Tree leaves for flattening potentially nested axis (multiple names can exist for single array axis).
            replication_axis_names = [
                name for name in axis_names if name not in jax.tree_util.tree_leaves(g.names)
            ]
            if len(replication_axis_names) == 0:
                # Parameters partitioned over all axes.
                return g
            else:
                # Average over remaining replicated axes.
                return g.replace(value=jax.lax.pmean(g.value, axis_name=replication_axis_names))
        else:
            # Parameters are replicated over all axes.
            return jax.lax.pmean(g, axis_name=axis_names)

    return jax.tree_map(sync_grad, grads, is_leaf=lambda x: isinstance(x, nn.Partitioned))

Besides the gradient averaging, we do not need to adjust anything else in our training step. One potential optimization is to cast the parameters to bfloat16 before we execute the apply_fn in the loss_fn (see the sketch below). This can reduce the communication overhead, since we only need to communicate half the amount of data. However, it only works well if all parameters can be kept in bfloat16, and is trickier to handle if we have a mix of parameter precisions in our model. Hence, for simplicity, we do not apply this optimization here and leave the loss function unchanged.
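For completeness, a hedged sketch of this casting (not used in this notebook) could look as follows; it handles both plain arrays and nn.Partitioned leaves and would be applied to the parameters right before calling apply_fn:

def cast_params_to_bf16(params: PyTree) -> PyTree:
    def _cast(p: Parameter) -> Parameter:
        value = p.value if isinstance(p, nn.Partitioned) else p
        if value.dtype == jnp.float32:
            value = value.astype(jnp.bfloat16)
        return p.replace(value=value) if isinstance(p, nn.Partitioned) else value

    return jax.tree_map(_cast, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))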

[33]:
def train_step_fsdp(
    state: TrainState,
    metrics: Metrics,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name,))
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

We can now wrap the train step with shard_map and jit it. We reuse our sharding specification of the train state, which we determined during the initialization, and specify the mesh and the sharding specifications for the metrics and batch as before.

[34]:
train_step_fsdp_fn = jax.jit(
    shard_map(
        train_step_fsdp,
        mesh,
        in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_fsdp_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
_, metric_shapes = jax.eval_shape(
    train_step_fsdp_fn,
    state_fsdp,
    None,
    batch,
)
metrics_fsdp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)

Let’s run it again for a few steps and check the resulting metrics:

[35]:
for _ in range(15):
    state_fsdp, metrics_fsdp = train_step_fsdp_fn(state_fsdp, metrics_fsdp, batch)
final_metrics_fsdp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(state_fsdp, final_metrics_fsdp, batch)
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
/home/plippe/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[64]), ShapedArray(float32[98,512]), ShapedArray(float32[64,10]), ShapedArray(float32[64]), ShapedArray(float32[98,512]), ShapedArray(float32[64,10]), ShapedArray(float32[64]), ShapedArray(float32[98,512]), ShapedArray(float32[64,10]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
 FSDP - Final metrics
accuracy: 1.000000
loss: 0.003343

The model is training as expected and is able to overfit on the single batch of data.

Verify that FSDP gives same results as DP

Since the different parallelism strategies should result in the same model, we can verify that the FSDP model gives the same results as the data-parallel model without sharding. Further, the initialization of both models behaves in exactly the same way, such that the example training runs above should have produced the same outputs. We can verify this by comparing the metrics of the FSDP and DP model below:

[36]:
metrics_dp = jax.device_get(metrics_dp)
metrics_fsdp = jax.device_get(metrics_fsdp)
for key in metrics_dp.keys():
    val_dp = metrics_dp[key][0] / metrics_dp[key][1]
    val_fsdp = metrics_fsdp[key][0] / metrics_fsdp[key][1]
    print(f"Metrics DP Avg {key}: {val_dp:.4f}")
    print(f"Metrics FSDP Avg {key}: {val_fsdp:.4f}")
    np.testing.assert_allclose(val_dp, val_fsdp, atol=1e-2)
Metrics DP Avg accuracy: 0.8844
Metrics FSDP Avg accuracy: 0.8844
Metrics DP Avg loss: 0.5102
Metrics FSDP Avg loss: 0.5102

Both models have resulted in the same metrics, suggesting that the FSDP model is training as expected. We can also compare parameters and optimizer state to verify that they are the same.

[37]:
params_dp = jax.device_get({"params": state_dp.params, "opt_state": state_dp.opt_state})
params_fsdp = jax.device_get({"params": state_fsdp.params, "opt_state": state_fsdp.opt_state})
params_fsdp = jax.tree_map(
    lambda x: x.value if isinstance(x, nn.Partitioned) else x,
    params_fsdp,
    is_leaf=lambda x: isinstance(x, nn.Partitioned),
)
_ = jax.tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-4), params_dp, params_fsdp)
print("Parameters match between DP and FSDP")
Parameters match between DP and FSDP

We find that both methods result in the same output, verifying that we have implemented FSDP correctly. Note that tiny numerical differences can still occur due to the different order of reductions in the gradient synchronization. However, these differences should be negligible and not affect the training of the model.

Conclusion

In this tutorial, we have introduced the basic building blocks of distributed computing in JAX, and implemented data parallelism and fully-sharded data parallelism. In data parallelism, we shard the input batch over the data axis, while the model parameters are replicated across all devices. This allows for a larger batch size and faster training. For large models, we further reduced the memory footprint by sharding the model parameters across devices. In this fully-sharded data parallelism (FSDP), we have seen how to shard the model parameters, and how to gather and scatter the parameters and gradients in the forward and backward pass. Data parallelism is one of the most common parallelism strategies in deep learning, and FSDP is a powerful tool to scale to very large models. Still, we may encounter situations where the model is limited by the memory of a single device, and we need expensive remat strategies to overcome this. Alternatively, we can look at other parallelism strategies which shard the model execution itself over devices. These strategies, also called model parallelism, will be the focus of the next tutorials, specifically pipeline parallelism and tensor parallelism. With these, we can scale to even larger models and train billion-parameter models efficiently.

References and Resources

[Rajbhandari et al., 2020] Rajbhandari, S., Rasley, J., Ruwase, O. and He, Y., 2020. Zero: Memory optimizations toward training trillion parameter models. In SC20: International Conference for High Performance Computing, Networking, Storage and Analysis (pp. 1-16). Paper link

[Wang and Komatsuzaki, 2021] Wang, B., and Komatsuzaki, A., 2021. Mesh transformer jax. GitHub link

[Beyer et al., 2022] Beyer, L., Zhai, X., and Kolesnikov, A., 2022. Big Vision. GitHub link

[Google, 2024] JAX Team Google, 2024. Distributed arrays and automatic parallelization. Notebook link

[Google, 2024] JAX Team Google, 2024. SPMD multi-device parallelism with shard_map. Notebook link

[Google, 2024] JAX Team Google, 2024. Using JAX in multi-host and multi-process environments. Notebook link

[DeepSpeed, 2024] DeepSpeed, 2024. Zero Redundancy Optimizer. Tutorial link


Star our repository: If you found this tutorial helpful, consider ⭐-ing our repository.
Ask questions: For any questions, typos, or bugs that you found, please raise an issue on GitHub.