Part 4.2: Asynchronous Linear Layers with Tensor Parallelism


Author: Phillip Lippe

In Part 4.1, we implemented a parallel linear layer using tensor parallelism. However, our implementation has one major inefficiency: it does not overlap communication with computation. In the gather strategy, we first need to communicate the features to all devices before we can compute the output, so all devices sit idle until the communication has finished. Similarly, in the scatter strategy, we first need to compute the full output on all devices before we can communicate the results and sum them, so the devices then have to wait for the communication to finish before continuing with subsequent layers. We would like to avoid this idle time wherever possible.

To tackle this challenge, we will implement asynchronous linear layers in this notebook. Using an asynchronous gather and scatter strategy, we will overlap communication with computation, which allows us to hide the communication latency and improve the overall performance of our model. The techniques in this notebook are inspired by the ViT-22b model, which used them to scale up Vision Transformers. In Part 4.3, we will apply these techniques to scale up our Transformer-based language model.

Prerequisites

First, let’s set up the basic environment and the utility functions we have seen in previous notebooks. We download the Python scripts of the previous notebooks below. This is only needed when running on Google Colab; local execution skips this step automatically.

[1]:
import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download.
python_files = ["single_gpu.py", "data_parallel.py", "pipeline_parallel.py", "tensor_parallel.py", "utils.py"]
# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

As before, we simulate 8 devices on CPU to demonstrate the parallelism without the need for multiple GPUs or TPUs. If you are running on your local machine and have multiple GPUs available, you can comment out the lines below.

[2]:
from utils import simulate_CPU_devices

simulate_CPU_devices()

We now import our standard libraries.

[3]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, List, Literal, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict

PyTree = Any
Parameter = jax.Array | nn.Partitioned
Metrics = Dict[str, Tuple[jax.Array, ...]]

We also import the utility functions from the previous notebooks. Our notebook will rely on the ModelParallelismWrapper from the pipeline parallelism notebook and on functions from Part 4.1. If you are not familiar with these modules, it is recommended to look at their implementation before continuing.

[4]:
from pipeline_parallel import ModelParallelismWrapper
from single_gpu import Batch, print_metrics
from tensor_parallel import (
    MLPBlockInput,
    MLPBlockOutput,
    TPClassifier,
    get_default_tp_classifier_config,
    init_tp,
    scale_init,
    train_step_tp,
)

Additionally, we recreate the config, mesh, and batch from Part 4.1 to use the same task as before.

[5]:
# Load the default configuration for the classifier.
config = get_default_tp_classifier_config()
# Initialize multi-device mesh.
device_array = np.array(jax.devices()).reshape(-1, config.model_axis_size)
mesh = Mesh(device_array, (config.data_axis_name, config.model_axis_name))
# Batch for random classification task.
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
batch = Batch(
    inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)),
    labels=jax.random.randint(
        data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
    ),
)

Tensor Parallelism with Compute-Communication Overlap

In the previous notebook on pipeline parallelism, we discussed how to overlap communication with computation by sending sub-results as soon as they become available. We can apply the same principle to tensor parallelism. At the start of the linear layer, each device holds a subset of the input features. It can directly start computing the output with respect to these local features, which corresponds to \(A_{i,i}x_{i}\) in our original notation. While it does so, it can already send its features to the next device, such that communication overlaps with computation. As soon as the next device receives the features, it can compute the output with respect to them, and the process continues. This way, we avoid waiting for all communication to finish before computing any output, and bring the efficiency of the layer closer to the theoretical maximum.
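Written out in block notation (assuming, as the reference to \(A_{i,i}x_{i}\) suggests, that \(y = Ax\) with \(A\) split into \(N \times N\) blocks for \(N\) devices), the output shard of device \(i\) separates into one term that needs no communication and \(N-1\) terms that each wait for one remote feature shard:

\[
y_i = \sum_{j=0}^{N-1} A_{i,j}x_j = \underbrace{A_{i,i}x_i}_{\text{local}} + \sum_{j \neq i} \underbrace{A_{i,j}x_j}_{\text{needs } x_j \text{ from device } j}
\]

Since the sum can be accumulated in any order, each remaining term can be added as soon as its shard \(x_j\) has arrived.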

Before we can implement this strategy, we need a way to perform the communication asynchronously with respect to the computation. In JAX, we can do this with jax.lax.ppermute, the parallel permutation operation we have seen in previous notebooks. With a ring permutation, it sends an array to the next device and receives an array from the previous device. As long as we do not access the received array before it has been fully communicated, the computation can continue, and communication overlaps with computation. We will use this operation to implement asynchronous versions of our gather and scatter-sum operations.
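To make the semantics of jax.lax.ppermute concrete without multiple devices, the sketch below emulates it in NumPy over an array of per-device values. The helper `emulate_ppermute` is hypothetical and not part of JAX. It follows the documented behavior that each `(source, destination)` pair copies the source's value to the destination, and that devices which are not a destination receive zeros. With the ring permutation used throughout this notebook, the emulation reduces to a cyclic shift.

```python
import numpy as np


def emulate_ppermute(values: np.ndarray, perm: list[tuple[int, int]]) -> np.ndarray:
    """Hypothetical single-process stand-in for jax.lax.ppermute.

    values[d] is the array held by device d; perm lists (source, destination) pairs.
    """
    out = np.zeros_like(values)  # Devices that are not a destination end up with zeros.
    for src, dst in perm:
        out[dst] = values[src]
    return out


n = 4
shift_up_perm = [(j, (j + 1) % n) for j in range(n)]
values = np.arange(n, dtype=np.float32)  # Device d holds the value d.
received = emulate_ppermute(values, shift_up_perm)
print(received)  # [3. 0. 1. 2.]: device 0 receives from device 3, device 1 from device 0, ...

# On a full ring, the permutation is a cyclic shift along the device axis.
assert np.array_equal(received, np.roll(values, 1))
```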

Async Gather

We start with the gather strategy. Each device holds a subset of the input features, and we want to communicate the input features asynchronously to all devices. We start by performing a jax.lax.ppermute to send the current features to the next device in our ring topology. This means device 0 sends the features to device 1, device 1 sends the features to device 2, and so on, until the last device sends its features to device 0. However, instead of concatenating the communicated features with the current features, as we would get with jax.lax.all_gather, we collect the communicated arrays in a list. This way, the features present on the device and the communicated features become two independent arrays, on which we can perform independent operations. For instance, we could already start computing the output with respect to the current features, while waiting for the communication to finish.

At this point, each device has its own features and those of its predecessor in the ring. However, we want all features to be gathered on all devices. We achieve this by forwarding the newly received features to the next device as soon as they become available. For example, device 0 sends the features it received from device 3 to device 1, device 1 sends the features it received from device 0 to device 2, and so on. By continuing this pattern for \(N-1\) steps, where \(N\) is the number of devices, every device ends up with all features while using the minimal amount of communication: each device sends and receives only \(N-1\) feature shards.
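The forwarding pattern can be simulated in NumPy on a single process. In this sketch, the helper `simulate_ring_gather` is hypothetical, and `np.roll` by one position stands in for one `jax.lax.ppermute` step along the ring (device \(j\) receives from device \(j-1\)). Each device keeps a list of shards and forwards the shard it received last.

```python
import numpy as np


def simulate_ring_gather(shards: np.ndarray) -> list[list[float]]:
    """Simulate a unidirectional ring all-gather (device j sends to device j + 1).

    shards[d] is the feature shard initially held by device d.
    Returns, for each device, the shards in the order they become available.
    """
    n = len(shards)
    lists = [[shards[d]] for d in range(n)]  # Each device starts with its own shard.
    in_flight = shards.copy()  # The shard each device forwards in the next step.
    for _ in range(n - 1):
        in_flight = np.roll(in_flight, 1)  # One ring step: device j receives from j - 1.
        for d in range(n):
            lists[d].append(in_flight[d])
    return [[float(v) for v in lst] for lst in lists]


gathered = simulate_ring_gather(np.arange(4, dtype=np.float32))
for d, lst in enumerate(gathered):
    print(f"Device {d}: {lst}")
# Device 0: [0.0, 3.0, 2.0, 1.0]
# Device 1: [1.0, 0.0, 3.0, 2.0]
# Device 2: [2.0, 1.0, 0.0, 3.0]
# Device 3: [3.0, 2.0, 1.0, 0.0]
```

After \(N-1 = 3\) steps, every device holds every shard exactly once, and each device has received exactly three shards.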

The communication pattern is visualized below. The first column shows the initial features on each device, and each block represents the list of features per device. The second column shows the communicated features after a single jax.lax.ppermute, and so on. The arrows indicate the communication pattern.

[Figure: communication pattern of the asynchronous ring gather]

We note that, in contrast to the all-gather operation, each device ends up with a different order of the features. This is not a problem for our purposes: the subsequent matrix multiplication multiplies each feature shard with its own block of the weight matrix and sums the results, and since the weight matrix is learned from scratch, its blocks simply adapt to the order in which the shards arrive. Still, one should keep in mind that this leads to a different weight layout than in the all-gather version, and converting between the two requires a permutation of the weight matrix.
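A small NumPy check of this claim, assuming a Flax-style kernel layout `(in_features, out_features)` and 4 devices. With the shift-up ring, device 0 receives the shards in the order \([x_0, x_3, x_2, x_1]\). Concatenating the shards in that order and permuting the kernel's row blocks in the same way gives the same result as the all-gather order with the original kernel. All names below are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
n, shard_dim, d_out, batch = 4, 3, 5, 2
x_shards = [rng.normal(size=(batch, shard_dim)) for _ in range(n)]
# Row blocks of the kernel: block j multiplies feature shard x_j.
w_blocks = [rng.normal(size=(shard_dim, d_out)) for _ in range(n)]

# All-gather order [x_0, x_1, x_2, x_3] with the kernel in its original block order.
y_all_gather = np.concatenate(x_shards, axis=-1) @ np.concatenate(w_blocks, axis=0)

# Ring order on device 0 (shift up): entry k holds shard (0 - k) mod n.
ring_order = [(0 - k) % n for k in range(n)]  # [0, 3, 2, 1]
x_ring = np.concatenate([x_shards[j] for j in ring_order], axis=-1)
w_ring = np.concatenate([w_blocks[j] for j in ring_order], axis=0)  # Same permutation of row blocks.
y_ring = x_ring @ w_ring

# Equivalently, sum the per-shard products without concatenating anything.
y_sum = sum(x_shards[j] @ w_blocks[j] for j in ring_order)

print(ring_order)
assert np.allclose(y_all_gather, y_ring)
assert np.allclose(y_all_gather, y_sum)
```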

Let’s now implement the async gather strategy below. In the implementation, we can choose the direction of communication along the ring (device 0 sends to device 1, or device 1 sends to device 0); we will come back to this later. We further keep the implementation general enough to support PyTrees that need to be communicated across devices. This becomes helpful for more complex models, where the module’s input or output is a PyTree.

[6]:
def async_gather(x: PyTree, axis_name: str, shift_up: bool = True) -> List[PyTree]:
    """All gather using ring permutation.

    Args:
        x: The input to gather.
        axis_name: The axis name to gather along.
        shift_up: Whether to shift up (device 0 sends to device 1) or down (device 1 sends to device 0).

    Returns:
        List of gathered inputs.
    """
    tp_size = jax.lax.psum(1, axis_name)
    # Determine communication permutation.
    if shift_up:
        shift_perm = [(j, (j + 1) % tp_size) for j in range(tp_size)]
    else:
        shift_perm = [(j, (j - 1) % tp_size) for j in range(tp_size)]
    ps = [x]
    p = x
    # Perform all-gather using ring permutation.
    for _ in range(1, tp_size):
        p = jax.lax.ppermute(p, axis_name, perm=shift_perm)
        ps.append(p)
    return ps

The output is now a list of arrays, on which we can perform independent, asynchronous operations. Each operation can be scheduled as soon as its features become available, and we are only blocked when we access an array that has not finished communicating yet. As long as the individual operations utilize the device sufficiently, communication is overlapped with computation.

Example

Let’s make a small example to illustrate the async gather strategy. We will use a simple feature array of shape (2, 4, 1) (data axis, model axis, feature axis), and split it over our 8 devices.

[7]:
x = np.arange(jax.local_device_count(), dtype=jnp.float32)
x = np.reshape(x, (-1, config.model_axis_size, 1))
x
[7]:
array([[[0.],
        [1.],
        [2.],
        [3.]],

       [[4.],
        [5.],
        [6.],
        [7.]]], dtype=float32)

We now call the async gather function with the feature array over the model axis. We will use the default communication direction, which is from device 0 to device 1.

[8]:
gather_model_fn = shard_map(
    functools.partial(async_gather, axis_name=config.model_axis_name),
    mesh=mesh,
    in_specs=P(config.data_axis_name, config.model_axis_name),
    out_specs=P(config.data_axis_name, config.model_axis_name),
)
x_gather_model = gather_model_fn(x)
x_gather_model = jax.device_get(x_gather_model)
for idx in range(jax.local_device_count()):
    print(
        f"Device {idx}: {[feat.reshape(-1, feat.shape[-1])[idx].item() for feat in x_gather_model]}"
    )
Device 0: [0.0, 3.0, 2.0, 1.0]
Device 1: [1.0, 0.0, 3.0, 2.0]
Device 2: [2.0, 1.0, 0.0, 3.0]
Device 3: [3.0, 2.0, 1.0, 0.0]
Device 4: [4.0, 7.0, 6.0, 5.0]
Device 5: [5.0, 4.0, 7.0, 6.0]
Device 6: [6.0, 5.0, 4.0, 7.0]
Device 7: [7.0, 6.0, 5.0, 4.0]

The output is a list of arrays, where the k-th array holds the features of the device k steps earlier in the ring. Devices 0 to 3 show the same communication pattern as visualized above: in the first step, device 0 sends its own features to device 1 and receives the features of device 3. If we had flipped the direction, device 0 would have 1.0 as the second entry of its list, i.e. its first communicated array. Meanwhile, devices 4 to 7 form an independent model group, since they are stacked over the data axis. Thus, device 4 sends its features to device 5 in the first step, and receives the features of device 7.
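The printed lists follow a closed-form pattern, which the sketch below reproduces. The helper `ring_gather_order` is hypothetical. Within a model group of size \(N\), entry \(k\) on the device with model index \(i\) holds the shard of model index \((i - k) \bmod N\) when shifting up, and \((i + k) \bmod N\) when shifting down; the data axis only offsets the device ids of each group. Since the example array stores each device's id as its feature value, device ids and printed values coincide.

```python
def ring_gather_order(device: int, model_size: int, shift_up: bool = True) -> list[int]:
    """Hypothetical helper: ids of the devices whose shards appear, in order, in the gathered list."""
    group_offset = (device // model_size) * model_size  # First device id of this model group.
    i = device % model_size  # Index of the device within its model group.
    sign = -1 if shift_up else 1
    return [group_offset + (i + sign * k) % model_size for k in range(model_size)]


for device in range(8):
    print(f"Device {device}: {ring_gather_order(device, model_size=4)}")
# Matches the printed output above, e.g. Device 5: [5, 4, 7, 6].

print(ring_gather_order(0, model_size=4, shift_up=False))  # [0, 1, 2, 3]
```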

Bidirectional Communication

Our implementation of the async gather strategy allows us to communicate the features in both directions, which we can exploit further. For instance, in TPU superpods, TPUs are connected in a 2D/3D torus mesh, such that each TPU has an interconnect to all its neighbors. Therefore, we can communicate the features in both directions at the same time to maximize our usage of bandwidth (more details here). Similarly, NVLink connections between GPUs allow for bidirectional communication, which we can exploit.

At each step, we perform a jax.lax.ppermute in both directions and collect the communicated features. Note that we need to track the features sent in each direction separately, so that in the second step, device 1 forwards device 0’s features to device 2 and not back to device 0. In the code below, the loop alternates between the two directions; since the upward and the downward chain do not depend on each other, the compiler can execute one upward and one downward transfer concurrently. Further, we obtain the same list order as in the unidirectional version by keeping two separate lists and merging them afterwards. The communication pattern is visualized below.

[Figure: communication pattern of the bidirectional asynchronous gather]
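The merge step can be checked with a small single-process simulation. The helper below is hypothetical and mimics the alternating loop of the implementation, with `np.roll` standing in for a ring `jax.lax.ppermute`. For 4 devices, the upward chain performs two transfers and the downward chain one. Concatenating the own shard, the upward shards, and the reversed downward shards reproduces the unidirectional shift-up order.

```python
import numpy as np


def simulate_bidirectional_gather(shards: np.ndarray) -> list[list[int]]:
    """Mimic the alternating up/down loop on one process and merge the two lists.

    shards[d] is the value initially held by device d.
    """
    n = len(shards)
    p_up, p_down = shards.copy(), shards.copy()
    ps_up, ps_down = [], []
    for i in range(1, n):
        if i % 2 == 0:
            p_down = np.roll(p_down, -1)  # Shift down: device j receives from device j + 1.
            ps_down.append(p_down)
        else:
            p_up = np.roll(p_up, 1)  # Shift up: device j receives from device j - 1.
            ps_up.append(p_up)
    merged = [shards] + ps_up + ps_down[::-1]  # Order of the unidirectional shift-up gather.
    return [[int(p[d]) for p in merged] for d in range(n)]


for d, order in enumerate(simulate_bidirectional_gather(np.arange(4))):
    print(f"Device {d}: {order}")
# Device 0: [0, 3, 2, 1]
# Device 1: [1, 0, 3, 2]
# Device 2: [2, 1, 0, 3]
# Device 3: [3, 2, 1, 0]
```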

We can now implement the bidirectional async gather strategy below, with the same principles as the non-bidirectional version.

[9]:
def async_gather_bidirectional(
    x: jax.Array, axis_name: str, shift_up: bool = True
) -> List[jax.Array]:
    """All gather using ring permutation with bidirectional communication.

    Args:
        x: The input to gather.
        axis_name: The axis name to gather along.
        shift_up: Whether to return the order of tensors that complies with the unidirectional version of shift up (device 0 sends to device 1) or down (device 1 sends to device 0).

    Returns:
        List of gathered inputs.
    """
    tp_size = jax.lax.psum(1, axis_name)
    shift_up_perm = [(j, (j + 1) % tp_size) for j in range(tp_size)]
    shift_down_perm = [(j, (j - 1) % tp_size) for j in range(tp_size)]
    ps_up = []
    ps_down = []
    p_up = x
    p_down = x
    for i in range(1, tp_size):
        if i % 2 == 0:
            p_down = jax.lax.ppermute(p_down, axis_name=axis_name, perm=shift_down_perm)
            ps_down.append(p_down)
        else:
            p_up = jax.lax.ppermute(p_up, axis_name=axis_name, perm=shift_up_perm)
            ps_up.append(p_up)
    # Combine communication in both directions.
    # This list will have the same order as the unidirectional up version.
    if shift_up:
        ps = [x] + ps_up + ps_down[::-1]
    else:
        ps = [x] + ps_down + ps_up[::-1]
    return ps

As before, we can make a small example to illustrate the bidirectional async gather strategy. We will use the same feature array as before.

[10]:
gather_bidir_model_fn = shard_map(
    functools.partial(async_gather_bidirectional, axis_name=config.model_axis_name),
    mesh=mesh,
    in_specs=P(config.data_axis_name, config.model_axis_name),
    out_specs=P(config.data_axis_name, config.model_axis_name),
)
x_gather_model = gather_bidir_model_fn(x)
x_gather_model = jax.device_get(x_gather_model)
for idx in range(jax.local_device_count()):
    print(
        f"Device {idx}: {[feat.reshape(-1, feat.shape[-1])[idx].item() for feat in x_gather_model]}"
    )
Device 0: [0.0, 3.0, 2.0, 1.0]
Device 1: [1.0, 0.0, 3.0, 2.0]
Device 2: [2.0, 1.0, 0.0, 3.0]
Device 3: [3.0, 2.0, 1.0, 0.0]
Device 4: [4.0, 7.0, 6.0, 5.0]
Device 5: [5.0, 4.0, 7.0, 6.0]
Device 6: [6.0, 5.0, 4.0, 7.0]
Device 7: [7.0, 6.0, 5.0, 4.0]

The result is identical to the unidirectional version, but we have communicated the features in both directions at the same time. This can be useful to maximize the usage of the interconnect bandwidth, and is especially useful in TPU nodes.

If you look carefully at the diagram, you can spot another minor inefficiency. With an even number of devices, the number of communication steps \(N-1\) is odd, so the last step is unidirectional again. An alternative bidirectional strategy that overcomes this is to split the features over the hidden dimension and communicate one half in each direction. This way, we always communicate in both directions, and may reduce the latency until the first features arrive, since the messages are smaller. However, this strategy requires more communication steps and operates on smaller arrays, which may lower device utilization depending on the feature and operation size. Additionally, it produces a different list structure than the unidirectional version. We implement this strategy below.

[11]:
def async_gather_split(x: jax.Array, axis_name: str) -> List[jax.Array]:
    """All gather using ring permutation with features split for bidirectional communication.

    Args:
        x: The input to gather.
        axis_name: The axis name to gather along.

    Returns:
        List of gathered inputs. Length is 2 * axis size.
    """
    x1, x2 = jax.tree_util.tree_map(lambda x: jnp.split(x, 2, axis=-1), x)
    return async_gather(x1, axis_name, shift_up=True) + async_gather(x2, axis_name, shift_up=False)

We can make a small example again, where we double the feature dimension of x to allow for splitting over the hidden dimension.

[12]:
gather_split_model_fn = shard_map(
    functools.partial(async_gather_split, axis_name=config.model_axis_name),
    mesh=mesh,
    in_specs=P(config.data_axis_name, config.model_axis_name),
    out_specs=P(config.data_axis_name, config.model_axis_name),
)
x_double = np.concatenate([x, x + 0.5], axis=-1)
x_gather_model = gather_split_model_fn(x_double)
x_gather_model = jax.device_get(x_gather_model)
for idx in range(jax.local_device_count()):
    print(
        f"Device {idx}: {[feat.reshape(-1, feat.shape[-1])[idx].item() for feat in x_gather_model]}"
    )
Device 0: [0.0, 3.0, 2.0, 1.0, 0.5, 1.5, 2.5, 3.5]
Device 1: [1.0, 0.0, 3.0, 2.0, 1.5, 2.5, 3.5, 0.5]
Device 2: [2.0, 1.0, 0.0, 3.0, 2.5, 3.5, 0.5, 1.5]
Device 3: [3.0, 2.0, 1.0, 0.0, 3.5, 0.5, 1.5, 2.5]
Device 4: [4.0, 7.0, 6.0, 5.0, 4.5, 5.5, 6.5, 7.5]
Device 5: [5.0, 4.0, 7.0, 6.0, 5.5, 6.5, 7.5, 4.5]
Device 6: [6.0, 5.0, 4.0, 7.0, 6.5, 7.5, 4.5, 5.5]
Device 7: [7.0, 6.0, 5.0, 4.0, 7.5, 4.5, 5.5, 6.5]

As one can see, the communication pattern differs from the unidirectional version, and the features are communicated in both directions at the same time. Due to the different list structure and the need for more operations, we will focus on the other bidirectional strategy in the following.
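To make the trade-off between the gather variants concrete, the sketch below counts sequential communication rounds and the per-round message size (as a fraction of one feature shard). It is only a round count under simplifying assumptions: transfers in opposite directions proceed concurrently, and latency, bandwidth, and compute overlap are ignored, so it is not a performance model. The function name is hypothetical.

```python
import math


def gather_schedule(n: int, strategy: str) -> list[tuple[int, float]]:
    """Per round: (number of directions in use, message size as a fraction of one shard)."""
    if strategy == "unidirectional":
        return [(1, 1.0)] * (n - 1)
    if strategy == "bidirectional":
        up_rounds, down_rounds = math.ceil((n - 1) / 2), (n - 1) // 2
        return [(2 if r < down_rounds else 1, 1.0) for r in range(up_rounds)]
    if strategy == "split":
        return [(2, 0.5)] * (n - 1)  # Half of each shard travels in each direction.
    raise ValueError(f"Unknown strategy: {strategy}")


for strategy in ["unidirectional", "bidirectional", "split"]:
    schedule = gather_schedule(4, strategy)
    print(strategy, len(schedule), schedule)
# unidirectional 3 [(1, 1.0), (1, 1.0), (1, 1.0)]
# bidirectional 2 [(2, 1.0), (1, 1.0)]
# split 3 [(2, 0.5), (2, 0.5), (2, 0.5)]
```

For 4 devices, the last bidirectional round uses only one direction, whereas with an odd number of devices (e.g. 5) every bidirectional round uses both.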

Async Scatter

We now turn to the scatter implementation. Scatter is the opposite situation to gather: all inputs are already available at the start of the operation, and instead, each device computes partial outputs that need to be sent to other devices and summed there. Thus, we want to start communicating as soon as a partial output needed on another device has been computed. The asynchronous scatter strategy is visualized below.

[Figure: communication pattern of the asynchronous ring scatter]

As input, each device holds its full list of partial outputs. For clarity, we denote the arrays on device 0 as \(a_0,...,a_3\), those on device 1 as \(b_0,...,b_3\), and so on, where \(a_i\) corresponds to \(y^{(0)}_i\) in our earlier notation. Note that not all arrays need to be computed at this point; we start from a list of arrays mainly to describe the computation graph to the compiler. In eager mode, this corresponds to the setup where the CPU has dispatched the computation of an array to the GPU and can already continue with the next operation until the array’s values are needed.

In the first step, every device sends its first output part to the next device in the ring. This communication can start as soon as \(a_0\), \(b_0\), \(c_0\), and \(d_0\) are available (highlighted in red), and overlaps with the computation of \(a_1\), \(b_1\), \(c_1\), and \(d_1\). Each received array is then added to the local output part \(a_1\), \(b_1\), \(c_1\), or \(d_1\) as soon as that is available, and the resulting partial sum is sent in the next round of communication. This round overlaps with the computation of \(a_2\), \(b_2\), \(c_2\), and \(d_2\), and so on. The final output on each device is the sum of one output part from every device, which follows the scatter pattern. However, in comparison to the non-async scatter, the assignment of output parts to devices differs: device 0 has the sum \(a_3+b_0+c_1+d_2\), while in the non-async scatter, it would have been \(a_0+b_0+c_0+d_0\). As for the async gather operation, this is usually not a problem, since the learned weights of the linear layer adapt to this assignment.
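The round-by-round accumulation can be traced symbolically in plain Python. In this sketch (hypothetical helper), each partial sum is a list of labels such as `"a3"`; a ring step moves every device's partial sum to the next device, and each device then appends its next local output part. It reproduces the sums from the diagram.

```python
def trace_ring_scatter(n: int = 4, shift_up: bool = True) -> list[str]:
    """Symbolically trace the ring scatter-sum: device d owns parts '<letter><j>' for j in range(n)."""
    letters = "abcdefgh"[:n]
    partial = [[f"{letters[d]}0"] for d in range(n)]  # Each device starts with its part 0.
    for j in range(1, n):
        # One ring step: with shift_up, device d receives the partial sum of device d - 1.
        partial = [partial[(d - 1) % n] if shift_up else partial[(d + 1) % n] for d in range(n)]
        # Add the next locally computed output part.
        partial = [partial[d] + [f"{letters[d]}{j}"] for d in range(n)]
    return [" + ".join(sorted(terms)) for terms in partial]


for d, terms in enumerate(trace_ring_scatter()):
    print(f"Device {d}: {terms}")
# Device 0: a3 + b0 + c1 + d2
# Device 1: a2 + b3 + c0 + d1
# Device 2: a1 + b2 + c3 + d0
# Device 3: a0 + b1 + c2 + d3
```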

We can now implement the async scatter strategy below. Given a list of arrays, we send the running partial sum to the next device in the ring and add the next local array to the partial sum we receive. After \(N-1\) steps, each device holds the sum of one array from every device.

[13]:
def async_scatter(xs: Sequence[PyTree], axis_name: str, shift_up: bool = True) -> PyTree:
    """Scatter sum using ring permutation.

    Args:
        xs: The inputs to scatter sum. The length of the list should match the size of the axis.
        axis_name: The axis name to scatter sum along.
        shift_up: Whether to shift up (device 0 sends to device 1) or down (device 1 sends to device 0).

    Returns:
        The scatter summed output.
    """
    tp_size = jax.lax.psum(1, axis_name)
    assert (
        len(xs) == tp_size
    ), f"Number of shards needs to match axis size, but got {len(xs)} with {axis_name} axis size {tp_size}."
    if shift_up:
        shift_perm = [(j, (j + 1) % tp_size) for j in range(tp_size)]
    else:
        shift_perm = [(j, (j - 1) % tp_size) for j in range(tp_size)]
    y = xs[0]
    for x in xs[1:]:
        y = jax.lax.ppermute(y, axis_name, perm=shift_perm)
        y = jax.tree_util.tree_map(jnp.add, y, x)
    return y

Example

Let’s make a small example to illustrate the async scatter strategy. We will use a simple feature array of shape (2, 4, 4) (data axis, model axis, feature axis), and split it over our 8 devices.

[14]:
np_rng = np.random.default_rng(42)
x = np_rng.integers(
    low=0,
    high=10,
    size=(
        jax.local_device_count() // config.model_axis_size,
        config.model_axis_size,
        config.model_axis_size,
    ),
)
pprint(x)
array([[[0, 7, 6, 4],
        [4, 8, 0, 6],
        [2, 0, 5, 9],
        [7, 7, 7, 7]],

       [[5, 1, 8, 4],
        [5, 3, 1, 9],
        [7, 6, 4, 8],
        [5, 4, 4, 2]]])

In this example, device 0 has \(a_0=0, a_1=7, a_2=6, a_3=4\), device 1 has \(b_0=4, b_1=8, b_2=0, b_3=6\), and so on. We split the last axis into a list of four arrays and call the async scatter function over the model axis.

[15]:
scatter_model_fn = shard_map(
    lambda x: async_scatter(x, axis_name=config.model_axis_name),
    mesh=mesh,
    in_specs=P(config.data_axis_name, config.model_axis_name),
    out_specs=P(config.data_axis_name, config.model_axis_name),
)
xs = np.split(x, x.shape[-1], axis=-1)
y_scatter_model = scatter_model_fn(xs)
for idx in range(jax.local_device_count()):
    print(f"Device {idx}: {y_scatter_model.reshape(-1, y_scatter_model.shape[-1])[idx]}")
Device 0: [15]
Device 1: [21]
Device 2: [23]
Device 3: [20]
Device 4: [19]
Device 5: [28]
Device 6: [15]
Device 7: [14]

To check the result, we can do the operation by hand using the figure above, and get:

  • Output on device 0: \(a_3 + b_0 + c_1 + d_2 = 4 + 4 + 0 + 7 = 15\)

  • Output on device 1: \(a_2 + b_3 + c_0 + d_1 = 6 + 6 + 2 + 7 = 21\)

  • Output on device 2: \(a_1 + b_2 + c_3 + d_0 = 7 + 0 + 9 + 7 = 23\)

  • Output on device 3: \(a_0 + b_1 + c_2 + d_3 = 0 + 8 + 5 + 7 = 20\)

The result is identical to the expected output, suggesting we have successfully implemented the async scatter strategy.
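The hand calculation generalizes to a closed form, derived here from the ring schedule rather than taken from the original implementation. Part \(j\) on device \(k\) travels \(N-1-j\) hops upward, so the device with model index \(i\) ends up with part \((k - i - 1) \bmod N\) of every device \(k\) in its model group. The NumPy sketch below applies this rule to the example array and reproduces all eight printed outputs.

```python
import numpy as np

# Example array from above: (data groups, devices per model group, output parts per device).
x = np.array(
    [
        [[0, 7, 6, 4], [4, 8, 0, 6], [2, 0, 5, 9], [7, 7, 7, 7]],
        [[5, 1, 8, 4], [5, 3, 1, 9], [7, 6, 4, 8], [5, 4, 4, 2]],
    ]
)
n = x.shape[1]
# out[g, i] = sum over source devices k of part (k - i - 1) mod n of device k.
out = np.array(
    [[sum(x[g, k, (k - i - 1) % n] for k in range(n)) for i in range(n)] for g in range(x.shape[0])]
)
print(out.reshape(-1))  # [15 21 23 20 19 28 15 14], matching the async_scatter output.
```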

Bidirectional Communication

Similar to the async gather strategy, we can also communicate the features in both directions at the same time by splitting the features over the hidden dimension. This way, we can ensure that we always communicate in both directions, and may improve efficiency. We implement this strategy below.

[16]:
def async_scatter_split(xs: Sequence[PyTree], axis_name: str) -> PyTree:
    """Scatter sum using ring permutation with features split for bidirectional communication.

    Args:
        xs: The inputs to scatter sum. The length of the list should match the size of the axis.
        axis_name: The axis name to scatter sum along.

    Returns:
        The scatter summed output.
    """

    def _split(x: PyTree) -> Tuple[PyTree, PyTree]:
        return (
            jax.tree_util.tree_map(lambda x: x[..., : x.shape[-1] // 2], x),
            jax.tree_util.tree_map(lambda x: x[..., x.shape[-1] // 2 :], x),
        )

    tp_size = jax.lax.psum(1, axis_name)
    assert (
        len(xs) == tp_size
    ), f"Number of shards needs to match axis size, but got {len(xs)} with {axis_name} axis size {tp_size}."
    shift_perm_up = [(j, (j + 1) % tp_size) for j in range(tp_size)]
    shift_perm_down = [(j, (j - 1) % tp_size) for j in range(tp_size)]
    y_up, y_down = _split(xs[0])
    for x in xs[1:]:
        y_up = jax.lax.ppermute(y_up, axis_name, perm=shift_perm_up)
        y_down = jax.lax.ppermute(y_down, axis_name, perm=shift_perm_down)
        x_up, x_down = _split(x)
        y_up = jax.tree_util.tree_map(jnp.add, y_up, x_up)
        y_down = jax.tree_util.tree_map(jnp.add, y_down, x_down)
    return jax.tree_util.tree_map(lambda y1, y2: jnp.concatenate([y1, y2], axis=-1), y_up, y_down)

The first half of the features is processed in the same way as in the unidirectional version with shift_up=True, and the second half in the same way as with shift_up=False. We can verify this by repeating each entry of our example array along the last axis and checking the output.

[17]:
scatter_model_fn = shard_map(
    lambda x: async_scatter_split(x, axis_name=config.model_axis_name),
    mesh=mesh,
    in_specs=P(config.data_axis_name, config.model_axis_name),
    out_specs=P(config.data_axis_name, config.model_axis_name),
)
x_double = np.repeat(x, 2, axis=-1)
xs = np.split(x_double, x.shape[-1], axis=-1)
y_scatter_model = scatter_model_fn(xs)
for idx in range(jax.local_device_count()):
    print(f"Device {idx}: {y_scatter_model.reshape(-1, y_scatter_model.shape[-1])[idx]}")
Device 0: [15 11]
Device 1: [21 18]
Device 2: [23 27]
Device 3: [20 23]
Device 4: [19 16]
Device 5: [28 22]
Device 6: [15 18]
Device 7: [14 20]

The first feature dimension is indeed identical to the unidirectional output above. The second feature dimension can be verified by hand or by running the previous example with shift_up=False; it matches the expected output as well, suggesting we have successfully implemented the bidirectional async scatter strategy. Due to the different setup and the need for more operations, we will focus on the unidirectional scatter in the following, although the bidirectional version can be more efficient in some cases.
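For the second feature dimension, the sketch below re-runs the ring accumulation in NumPy for both directions. It is a single-process simulation, not the sharded JAX code: `np.roll` along the model axis stands in for `jax.lax.ppermute`, and the helper name is hypothetical. The downward shift (device \(j\) sends to device \(j-1\)) reproduces the second column of the printed output.

```python
import numpy as np

# Example array from above: (data groups, devices per model group, output parts per device).
x = np.array(
    [
        [[0, 7, 6, 4], [4, 8, 0, 6], [2, 0, 5, 9], [7, 7, 7, 7]],
        [[5, 1, 8, 4], [5, 3, 1, 9], [7, 6, 4, 8], [5, 4, 4, 2]],
    ]
)


def simulate_async_scatter(parts: np.ndarray, shift_up: bool) -> np.ndarray:
    """parts[g, d, j]: output part j computed on device d of model group g."""
    shift = 1 if shift_up else -1
    y = parts[:, :, 0]
    for j in range(1, parts.shape[-1]):
        y = np.roll(y, shift, axis=1)  # Ring step along the model axis.
        y = y + parts[:, :, j]  # Add the next local output part.
    return y


print(simulate_async_scatter(x, shift_up=True).reshape(-1))  # [15 21 23 20 19 28 15 14]
print(simulate_async_scatter(x, shift_up=False).reshape(-1))  # [11 18 27 23 16 22 18 20]
```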

Asynchronous Linear Layer

We can now use the async gather and scatter strategies in a linear layer, as done, for example, in the ViT-22b model. Both closely follow the asynchronous communication functions implemented above, with the computation interleaved.

In the gather strategy, we start by computing the output with respect to the local features (i.e. \(A_{i,i}x_i\) on device \(i\)), since they are already available on each device. At the same time, we send the features to the next device. Once the features have arrived and the computation has finished, we compute the output with respect to the received features, and continue the process. All outputs are summed to obtain the final output. This process is visualized below (figure credit: Dehghani et al., 2023).

[Figure: asynchronous gather in a linear layer (Dehghani et al., 2023)]
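The gather branch can be sketched in NumPy for a single simulated device. Assumptions: Flax-style kernel layout `(in_features, out_features)`, the unidirectional shift-up order, and hypothetical names and sizes. The device applies one small kernel per received shard and adds a bias only for the first shard, mirroring `use_bias=(i == 0)` in the TPAsyncDense layer; summing the \(N\) partial outputs then matches the device's slice of one full dense layer with a single bias. If every sub-layer carried a bias, the sum would contain \(N\) biases per output feature.

```python
import numpy as np

rng = np.random.default_rng(1)
n, in_shard, out_shard, batch = 4, 3, 2, 5  # Hypothetical sizes.
x = rng.normal(size=(batch, n * in_shard))
w = rng.normal(size=(n * in_shard, n * out_shard))  # Kernel layout: (in_features, out_features).
b = rng.normal(size=(n * out_shard,))
y_full = x @ w + b  # Reference: a single dense layer on one device.

device = 1  # Model index of the simulated device.
out_cols = slice(device * out_shard, (device + 1) * out_shard)
# Shift-up ring order: entry k of the gathered list holds input shard (device - k) mod n.
order = [(device - k) % n for k in range(n)]

y = np.zeros((batch, out_shard))
for k, j in enumerate(order):
    x_j = x[:, j * in_shard : (j + 1) * in_shard]  # Shard j, available after k ring steps.
    kernel_k = w[j * in_shard : (j + 1) * in_shard, out_cols]  # Weights of sub-layer k.
    y += x_j @ kernel_k
    if k == 0:  # Only the first sub-layer adds a bias.
        y += b[out_cols]

assert np.allclose(y, y_full[:, out_cols])
```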

In the scatter strategy, we start by computing the partial output that requires the longest path of communication. Once it is computed, we send it to the next device and compute the next partial output. We then add this output to the received partial sum, and continue the process. Last, we compute the partial output for the current device itself, which needs no further communication. This process is visualized below (figure credit: Dehghani et al., 2023).

[Figure: asynchronous scatter in a linear layer (Dehghani et al., 2023)]
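The weight bookkeeping behind the scatter strategy can also be checked in NumPy (single-process sketch, Flax-style kernel layout `(in_features, out_features)`, hypothetical names). Device \(k\) holds input shard \(x_k\) and computes \(N\) partial outputs. With the upward ring, partial output \(j\) travels \(N-1-j\) hops and is summed into the output of device \((k - j - 1) \bmod N\); the last one (\(j = N-1\)) stays on device \(k\). If each partial output uses the block of the full kernel that connects input shard \(k\) to that destination's output shard, the scatter-summed result on each device equals its slice of the full dense layer.

```python
import numpy as np

rng = np.random.default_rng(0)
n, in_shard, out_shard, batch = 4, 3, 2, 5  # Hypothetical sizes.
x = rng.normal(size=(batch, n * in_shard))
w = rng.normal(size=(n * in_shard, n * out_shard))
y_full = x @ w  # Reference: a single dense layer on one device.


def block(k: int, i: int) -> np.ndarray:
    """Kernel block mapping input shard k to output shard i."""
    return w[k * in_shard : (k + 1) * in_shard, i * out_shard : (i + 1) * out_shard]


x_shards = [x[:, k * in_shard : (k + 1) * in_shard] for k in range(n)]
# parts[k][j]: j-th partial output on device k, computed with the matching kernel block.
parts = [[x_shards[k] @ block(k, (k - j - 1) % n) for j in range(n)] for k in range(n)]

# Ring scatter-sum with shift up (device d receives from device d - 1).
y = [parts[d][0] for d in range(n)]
for j in range(1, n):
    y = [y[(d - 1) % n] + parts[d][j] for d in range(n)]

for i in range(n):
    assert np.allclose(y[i], y_full[:, i * out_shard : (i + 1) * out_shard])
```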

We extend our previous TPDense class to implement the asynchronous version. Instead of applying the dense layer once, we apply it in a loop over the sub-features in both strategies. We let the compiler decide how to schedule the individual communication and computation operations, which we expect to be close to the diagrams above. Since the dense layer is split into multiple smaller layers, each weight matrix has size \(d_y / \text{num}\_\text{devices} \times d_x / \text{num}\_\text{devices}\), so we may need to adjust the kernel initialization accordingly (e.g. a fan-in adjustment in both scatter and gather). Further, we ensure that each final output feature uses only a single bias parameter, to remain consistent with the non-parallelised models.

[18]:
class TPAsyncDense(nn.Module):
    """Tensor-Parallel Dense Layer with Asynchronous Communication.

    This layer can be used to perform a dense layer with Tensor Parallelism support, and overlaps communication with computation whenever possible.

    Attributes:
        dense_fn: Constructor function of the dense layer to use. Needs to support the keyword argument `kernel_init`.
        model_axis_name: The name of the model axis.
        tp_mode: The Tensor Parallelism mode to use. Can be "scatter", "gather", or "none".
        kernel_init: The initializer to use for the kernel of the dense layer.
        kernel_init_adjustment: The adjustment factor to use for the kernel initializer.
        dense_name: The name of the dense layer module.
        use_bidirectional_gather: Whether to use bidirectional or unidirectional gather over the device ring for communication.
    """

    dense_fn: Any
    model_axis_name: str
    tp_mode: Literal["scatter", "gather", "none"] = "none"
    kernel_init: Callable = nn.initializers.lecun_normal()
    kernel_init_adjustment: float = 1.0
    dense_name: str = "module"
    use_bidirectional_gather: bool = True

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        tp_size = jax.lax.psum(1, self.model_axis_name)
        tp_mode = self.tp_mode if tp_size > 1 else "none"

        dense_fn = functools.partial(
            ModelParallelismWrapper,
            model_axis_name=self.model_axis_name,
            module_fn=functools.partial(
                self.dense_fn,
                kernel_init=scale_init(self.kernel_init, self.kernel_init_adjustment),
            ),
            name=self.dense_name,
        )

        if tp_mode == "none":
            y = self.dense_fn(kernel_init=self.kernel_init, name="shard_0")(x)
        elif tp_mode == "gather":
            # Async gathering of all inputs.
            async_op = (
                async_gather_bidirectional if self.use_bidirectional_gather else async_gather
            )
            xs = async_op(x, axis_name=self.model_axis_name)
            # Compute output per input (scheduled as communication makes inputs available).
            ys = [
                dense_fn(
                    # Only one bias per final output feature is needed, so only shard_0 has one.
                    module_kwargs={"use_bias": (i == 0)},
                    name=f"shard_{i}",
                )(x_i)
                for i, x_i in enumerate(xs)
            ]
            # Final sum of all outputs.
            y = jax.tree_util.tree_map(lambda *args: sum(args), *ys)
        elif tp_mode == "scatter":
            # Calculate all outputs per device.
            ys = [
                dense_fn(
                    # Only one bias per final output feature is needed, so only shard_0 has one.
                    module_kwargs={"use_bias": (i == 0)},
                    name=f"shard_{i}",
                )(x)
                for i in range(tp_size)
            ]
            # Async scatter sum of all outputs (communication already starts after first output is ready).
            y = async_scatter(ys, axis_name=self.model_axis_name)
        else:
            raise ValueError(f"Unknown Tensor Parallel mode: {tp_mode}")
        return y
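The scatter branch of TPAsyncDense hands the communication to async_scatter, defined earlier. As a rough mental model of what such a scatter sum computes, the single-process sketch below simulates a unidirectional ring: every device starts by forwarding the partial output meant for its left neighbour, and at each hop the receiving device adds its own contribution for that accumulator's destination. Because the first partial output can be sent before the remaining ones are computed, communication can overlap with the other matrix multiplications. The helper ring_scatter_sum is hypothetical and is not the notebook's async_scatter, whose exact schedule may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np


def ring_scatter_sum(ys: jax.Array) -> jax.Array:
    """Simulated ring reduce-scatter (hypothetical helper, single process).

    ys[d, t] is the partial output that device d computed for device t's feature slice.
    Returns out with out[t] == sum over d of ys[d, t].
    """
    n = ys.shape[0]
    devices = np.arange(n)
    # Step 0: device d starts an accumulator destined for its left neighbour d - 1.
    acc = ys[devices, (devices - 1) % n]
    for step in range(1, n):
        # Pass every accumulator one hop to the right (device d - 1 -> device d) ...
        acc = jnp.roll(acc, shift=1, axis=0)
        # ... and add this device's contribution for the accumulator's destination.
        acc = acc + ys[devices, (devices - 1 - step) % n]
    return acc


num_devices = 4
ys = jax.random.normal(jax.random.PRNGKey(0), (num_devices, num_devices, 8, 128))
out = ring_scatter_sum(ys)
print(jnp.max(jnp.abs(out - ys.sum(axis=0))))
```

After n - 1 hops, each accumulator has visited every device exactly once and ends up on its destination device, so the result matches a plain sum over the device axis.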

Now, let’s use these asynchronous linear layers to improve the efficiency of our MLP blocks. We implement the same MLP block as before, but replace the TPDense layers with TPAsyncDense layers: we use the gather strategy for the first linear layer and the scatter strategy for the second. In addition, the input layer now needs the same initialization adjustment as the output layer: we scale the kernel initialization by \(\sqrt{1/\text{num\_devices}}\), since each of its smaller dense layers only receives \(1/\text{num\_devices}\) of the full feature size. Without this factor, summing \(\text{num\_devices}\) outputs, each produced by a fan-in-based initializer, would increase the output variance by a factor of \(\text{num\_devices}\).
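To see why the factor \(\sqrt{1/\text{num\_devices}}\) is needed, the short sketch below splits a 512-feature input into chunks, applies a separately initialized lecun_normal kernel (the default kernel_init of TPAsyncDense) to each chunk, and sums the results. It is a standalone illustration with made-up sizes rather than part of the notebook's model, and output_std is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def output_std(num_devices: int, scale: float, features: int = 512, seed: int = 0) -> float:
    """Std of the summed outputs of `num_devices` bias-free sub-layers (hypothetical helper)."""
    key_x, key_w = jax.random.split(jax.random.PRNGKey(seed))
    # Unit-variance inputs; each "device" only sees features // num_devices of them.
    x = jax.random.normal(key_x, (2048, features))
    x_chunks = jnp.split(x, num_devices, axis=-1)
    # lecun_normal draws weights with variance 1 / fan_in, here fan_in = features // num_devices.
    init = jax.nn.initializers.lecun_normal()
    y = jnp.zeros((x.shape[0], features))
    for key, x_i in zip(jax.random.split(key_w, num_devices), x_chunks):
        w_i = scale * init(key, (features // num_devices, features))
        y = y + x_i @ w_i
    return float(jnp.std(y))


for n in (1, 4):
    print(n, output_std(n, scale=1.0), output_std(n, scale=n**-0.5))
```

With four chunks, the unscaled output standard deviation should come out close to 2 (a variance of about 4), while scaling the kernels by \(4^{-1/2}\) brings it back to roughly 1, matching the single-layer case.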

While splitting the dense layer into multiple smaller layers works without problems, we also need to handle the other operations in the block. Element-wise operations such as the activation function can be applied to each input independently, so they need no adjustment. Normalization layers, however, commonly compute statistics over the full feature dimension. In the non-async implementation, the gather strategy made the full features available on every device, so the normalization could be applied inside the input layer. With the async strategy, each smaller dense layer only sees a slice of the features, so this is no longer possible. Instead, we apply the normalization layer before the input layer and compute its statistics across devices. Luckily, Flax already supports this: we only need to pass an axis_name to the normalization layer (see, e.g., the Flax documentation for RMSNorm). We additionally wrap it in a model parallelism wrapper, since each device holds the scaling parameters for its own features. This gives the same result as computing the statistics over the full feature size on a single device, and we can continue with the rest of the operations. Let’s implement this norm class below.
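The following standalone sketch, written in plain JAX rather than Flax, illustrates why reducing the statistics across devices reproduces the single-device result. It simulates four devices with jax.vmap and an assumed axis name "model", leaves out the learnable scale, and uses hand-written, hypothetical helpers (rms_norm_full, rms_norm_sharded), so details such as the epsilon may differ from nn.RMSNorm.

```python
import functools

import jax
import jax.numpy as jnp

EPS = 1e-6  # assumed epsilon for this sketch; nn.RMSNorm has its own default


def rms_norm_full(x: jax.Array) -> jax.Array:
    # Reference: statistics over the full feature dimension on a single device.
    return x * jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + EPS)


def rms_norm_sharded(x_local: jax.Array, axis_name: str) -> jax.Array:
    # Each device holds an equally sized feature slice, so averaging the local
    # mean of squares across devices yields the global mean of squares.
    mean_sq = jax.lax.pmean(jnp.mean(x_local**2, axis=-1, keepdims=True), axis_name)
    return x_local * jax.lax.rsqrt(mean_sq + EPS)


num_devices, batch_size, features = 4, 8, 512
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, features))
# Split the features into device slices: (num_devices, batch_size, features // num_devices).
x_shards = x.reshape(batch_size, num_devices, -1).transpose(1, 0, 2)
sharded_fn = functools.partial(rms_norm_sharded, axis_name="model")
y_shards = jax.vmap(sharded_fn, axis_name="model")(x_shards)
# Reassemble the feature dimension and compare against the single-device result.
y = y_shards.transpose(1, 0, 2).reshape(batch_size, features)
print(jnp.max(jnp.abs(y - rms_norm_full(x))))
```

The same idea carries over to shard_map, where the named axis refers to real devices instead of a vmapped dimension.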

[19]:
class TPNorm(nn.Module):
    """RMSNorm over features sharded along the model axis, with statistics reduced across devices."""

    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        x = ModelParallelismWrapper(
            model_axis_name=self.config.model_axis_name,
            module_fn=functools.partial(
                nn.RMSNorm,
                dtype=self.config.dtype,
                axis_name=self.config.model_axis_name,
            ),
            name="norm",
        )(x)
        return x

We use this normalization layer in the TPAsyncMLPBlock class below to define the whole block.

[20]:
class TPAsyncMLPBlock(nn.Module):
    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        tp_size = jax.lax.psum(1, self.config.model_axis_name)
        input_features = x.shape[-1]
        # Normalize across devices before the input layer.
        x = TPNorm(config=self.config, name="pre_norm")(x)
        # Input dense layer with async gather.
        x = TPAsyncDense(
            dense_fn=functools.partial(
                MLPBlockInput,
                config=self.config,
                features=self.config.hidden_size * self.config.mlp_expansion // tp_size,
                use_norm=False,
            ),
            model_axis_name=self.config.model_axis_name,
            tp_mode="gather",
            kernel_init_adjustment=tp_size**-0.5,
            name="input",
        )(x)
        # Output dense layer with async scatter.
        x = TPAsyncDense(
            dense_fn=functools.partial(
                MLPBlockOutput,
                config=self.config,
                features=input_features,
            ),
            model_axis_name=self.config.model_axis_name,
            tp_mode="scatter",
            kernel_init_adjustment=tp_size**-0.5,
            name="output",
        )(x)
        return x

Initialization

The rest of the model is identical to the non-async version. Let’s first create the classifier model with the new async MLP block. Note that the input and output layers of the model do not require the async strategy. The input layer operates on the full input features, which are already available on every device, so no communication is needed. For the output layer, we require the outputs to be available on all devices (or at least on one device), which we ensure via jax.lax.psum, a blocking communication.
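As a rough illustration of this final reduction, the sketch below simulates four devices with jax.vmap. Each simulated device holds a 128-feature slice of the hidden representation and the matching rows of a hypothetical output kernel, and jax.lax.psum sums the partial logits so that every device ends up with the same, complete result. The sizes and the axis name "model" are assumptions for the example, and the bias is omitted.

```python
import jax
import jax.numpy as jnp

num_devices, batch_size, hidden, num_classes = 4, 8, 512, 10
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (batch_size, hidden))
w = jax.random.normal(key_w, (hidden, num_classes)) / hidden**0.5

# Device d holds features [d * 128, (d + 1) * 128) and the matching kernel rows.
x_shards = x.reshape(batch_size, num_devices, -1).transpose(1, 0, 2)
w_shards = w.reshape(num_devices, -1, num_classes)


def output_layer(x_local: jax.Array, w_local: jax.Array) -> jax.Array:
    partial_logits = x_local @ w_local  # (batch_size, num_classes), only a partial sum
    # Blocking all-reduce: every device needs the complete logits before continuing.
    return jax.lax.psum(partial_logits, axis_name="model")


logits = jax.vmap(output_layer, axis_name="model")(x_shards, w_shards)
print(logits.shape)  # (4, 8, 10): one copy of the logits per simulated device
```

Since the loss cannot be computed before this sum is complete, there is no independent computation left to overlap it with, which is why a plain psum suffices here.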

[21]:
model_tp_async = TPClassifier(config.model, block_class=TPAsyncMLPBlock)
optimizer = optax.adamw(learning_rate=config.optimizer.learning_rate)

We reuse the same initialization function with the new model.

[22]:
init_tp_async_fn = shard_map(
    functools.partial(init_tp, model=model_tp_async, optimizer=optimizer),
    mesh,
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,
)
state_tp_async_shapes = jax.eval_shape(init_tp_async_fn, model_init_rng, batch.inputs)
state_tp_async_specs = nn.get_partition_spec(state_tp_async_shapes)

Let’s inspect how the async layers have impacted the parameter sharding in the MLP block.

[23]:
pprint(state_tp_async_specs.params)
{'input_layer': {'module': {'sharded': {'bias': PartitionSpec('model', None),
                                        'kernel': PartitionSpec('model', None, None)}}},
 'mlp': {'block': {'input': {'shard_0': {'sharded': {'dense': {'bias': PartitionSpec(None, 'model', None),
                                                               'kernel': PartitionSpec(None, 'model', None, None)}}},
                             'shard_1': {'sharded': {'dense': {'kernel': PartitionSpec(None, 'model', None, None)}}},
                             'shard_2': {'sharded': {'dense': {'kernel': PartitionSpec(None, 'model', None, None)}}},
                             'shard_3': {'sharded': {'dense': {'kernel': PartitionSpec(None, 'model', None, None)}}}},
                   'output': {'shard_0': {'sharded': {'dense': {'bias': PartitionSpec(None, 'model', None),
                                                                'kernel': PartitionSpec(None, 'model', None, None)}}},
                              'shard_1': {'sharded': {'dense': {'kernel': PartitionSpec(None, 'model', None, None)}}},
                              'shard_2': {'sharded': {'dense': {'kernel': PartitionSpec(None, 'model', None, None)}}},
                              'shard_3': {'sharded': {'dense': {'kernel': PartitionSpec(None, 'model', None, None)}}}},
                   'pre_norm': {'norm': {'sharded': {'scale': PartitionSpec(None, 'model', None)}}}}},
 'output_layer': {'module': {'sharded': {'bias': PartitionSpec('model', None),
                                         'kernel': PartitionSpec('model', None, None)}}}}

The input and output layers of the MLP block now each contain several sub-modules, one per smaller dense layer (i.e., one per device). Each sub-module has a kernel with the same sharding as before, but only the first one (shard_0) has a bias term. The normalization layer now sits outside of the input layer, with the same sharding as before. The input and output layers of the whole model did not change.
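As a sanity check on this parameter structure, the sketch below (plain JAX with hypothetical sizes, not the notebook's modules) shows that summing the outputs of four smaller dense layers, each applied to one feature chunk and with a bias only on the first one, is equivalent to a single dense layer whose kernel stacks the four sub-kernels along the input dimension.

```python
import jax
import jax.numpy as jnp

num_shards, in_features, out_features = 4, 512, 128
keys = jax.random.split(jax.random.PRNGKey(0), num_shards + 2)
x = jax.random.normal(keys[0], (8, in_features))
bias = jax.random.normal(keys[1], (out_features,))
# One kernel per shard, each mapping in_features // num_shards inputs to out_features outputs.
kernels = [
    jax.random.normal(k, (in_features // num_shards, out_features)) / in_features**0.5
    for k in keys[2:]
]

# Async view: one sub-layer per gathered chunk, summed afterwards.
chunks = jnp.split(x, num_shards, axis=-1)
y_async = jnp.zeros((x.shape[0], out_features))
for i, (x_i, k_i) in enumerate(zip(chunks, kernels)):
    y_async = y_async + x_i @ k_i
    if i == 0:
        y_async = y_async + bias  # the single bias lives in shard_0

# Single-layer view: concatenate the sub-kernels along the input dimension.
y_dense = x @ jnp.concatenate(kernels, axis=0) + bias
print(jnp.max(jnp.abs(y_async - y_dense)))
```

Adding a bias in every sub-layer would instead add it num_shards times, which is why only the first sub-module carries one.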

We can now continue with the initialization.

[24]:
init_tp_async_fn = jax.jit(
    shard_map(
        functools.partial(init_tp, model=model_tp_async, optimizer=optimizer),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_tp_async_specs,
        check_rep=False,
    ),
)
state_tp_async = init_tp_async_fn(model_init_rng, batch.inputs)

Let’s also inspect the parameter shapes to ensure the initialization worked as expected.

[25]:
print("TP Parameters - MLP Layers Pre-Norm")
pprint(
    jax.tree_util.tree_map(
        lambda x: x.shape, state_tp_async.params["mlp"]["block"]["pre_norm"]["norm"]["sharded"]
    )
)
print()
print("TP Parameters - MLP Layers Input")
pprint(
    jax.tree_util.tree_map(
        lambda x: x.shape, state_tp_async.params["mlp"]["block"]["input"]["shard_0"]["sharded"]
    )
)
print()
print("TP Parameters - MLP Layers Output")
pprint(
    jax.tree_util.tree_map(
        lambda x: x.shape, state_tp_async.params["mlp"]["block"]["output"]["shard_0"]["sharded"]
    )
)
TP Parameters - MLP Layers Pre-Norm
{'scale': Partitioned(value=(3, 4, 128),
                      names=(None, 'model', None),
                      mesh=None)}

TP Parameters - MLP Layers Input
{'dense': {'bias': Partitioned(value=(3, 4, 128),
                               names=(None, 'model', None),
                               mesh=None),
           'kernel': Partitioned(value=(3, 4, 128, 128),
                                 names=(None, 'model', None, None),
                                 mesh=None)}}

TP Parameters - MLP Layers Output
{'dense': {'bias': Partitioned(value=(3, 4, 128),
                               names=(None, 'model', None),
                               mesh=None),
           'kernel': Partitioned(value=(3, 4, 128, 128),
                                 names=(None, 'model', None, None),
                                 mesh=None)}}

Each dense sub-layer now has a smaller kernel of size 128 × 128, i.e. \(512 / 4 = 128\) features instead of 512 along the dimension that is now split into sub-layers, which is why we needed the adjusted initialization schemes. The scale parameter of the pre-norm follows the same sharding, so all parameter shapes are consistent with our expectations.

Training

We can now train the model with the async MLP block. The training loop is identical to the non-async version, and we expect the model to learn the task with high accuracy.

[26]:
train_step_tp_async_fn = jax.jit(
    shard_map(
        functools.partial(train_step_tp, config=config),
        mesh,
        in_specs=(state_tp_async_specs, P(), P(config.data_axis_name)),
        out_specs=(state_tp_async_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
state_shapes, metric_shapes = jax.eval_shape(
    train_step_tp_async_fn,
    state_tp_async,
    None,
    batch,
)
metrics_tp_async = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_tp_async, metrics_tp_async = train_step_tp_async_fn(state_tp_async, metrics_tp_async, batch)
/home/plippe/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[1,128]), ShapedArray(float32[1,784,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[1,10]), ShapedArray(float32[1,128,10]), ShapedArray(float32[1,128]), ShapedArray(float32[1,784,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[1,10]), ShapedArray(float32[1,128,10]), ShapedArray(float32[1,128]), ShapedArray(float32[1,784,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128,128]), ShapedArray(float32[3,1,128]), ShapedArray(float32[1,10]), ShapedArray(float32[1,128,10]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"

We train the model again for 15 steps and print the final loss and accuracy.

[27]:
for _ in range(15):
    state_tp_async, metrics_tp_async = train_step_tp_async_fn(
        state_tp_async, metrics_tp_async, batch
    )
final_metrics_tp_async = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_tp_async, final_metrics_tp_async = train_step_tp_async_fn(
    state_tp_async, final_metrics_tp_async, batch
)
print_metrics(final_metrics_tp_async, title="Final Metrics - Tensor Parallelism Async")
 Final Metrics - Tensor Parallelism Async
accuracy: 1.000000
loss: 0.000022

As expected, the model learns the task with high accuracy. We have successfully implemented the async gather and scatter strategies in our linear layer, which allow communication to overlap with computation. We can now continue to the next section, where we discuss the implementation of a full transformer model with tensor parallelism and fully-sharded data parallelism.

Intermediate Summary

In this notebook, we discussed the principles of tensor parallelism with compute-communication overlap. We implemented the asynchronous communication patterns of gather and scatter, and applied them to a linear layer distributed over multiple devices. We then used these layers to implement an asynchronous MLP block, and trained a model with it. In the next notebook, we will discuss how to implement tensor parallelism in a transformer model. Furthermore, we will profile the model to show the efficiency of the async communication patterns in such models.

References and Resources

[Shoeybi et al., 2019] Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J. and Catanzaro, B., 2019. Megatron-LM: Training multi-billion parameter language models using model parallelism. arXiv preprint arXiv:1909.08053. Paper link

[Wang and Komatsuzaki, 2021] Wang, B., and Komatsuzaki, A., 2021. Mesh transformer jax. GitHub link

[Xu et al., 2021] Xu, Y., Lee, H., Chen, D., Hechtman, B., Huang, Y., Joshi, R., Krikun, M., Lepikhin, D., Ly, A., Maggioni, M. and Pang, R., 2021. GSPMD: general and scalable parallelization for ML computation graphs. arXiv preprint arXiv:2105.04663. Paper link

[Dehghani et al., 2022] Dehghani, M., Gritsenko, A., Arnab, A., Minderer, M. and Tay, Y., 2022. Scenic: A JAX library for computer vision research and beyond. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 21393-21398). Paper link

[Yoo et al., 2022] Yoo, J., Perlin, K., Kamalakara, S.R. and Araújo, J.G., 2022. Scalable training of language models using JAX pjit and TPUv4. arXiv preprint arXiv:2204.06514. Paper link

[Chowdhery et al., 2023] Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung, H.W., Sutton, C., Gehrmann, S., Schuh, P., et al., 2023. PaLM: Scaling language modeling with pathways. Journal of Machine Learning Research, 24(240), pp.1-113. Paper link

[Anil et al., 2023] Anil, R., Dai, A.M., Firat, O., Johnson, M., Lepikhin, D., Passos, A., Shakeri, S., Taropa, E., Bailey, P., Chen, Z. and Chu, E., 2023. PaLM 2 technical report. arXiv preprint arXiv:2305.10403. Paper link

[Dehghani et al., 2023] Dehghani, M., Djolonga, J., Mustafa, B., Padlewski, P., Heek, J., Gilmer, J., Steiner, A.P., Caron, M., Geirhos, R., Alabdulmohsin, I., Jenatton, R., et al., 2023. Scaling vision transformers to 22 billion parameters. In International Conference on Machine Learning (pp. 7480-7512). PMLR. Paper link

[McKinney, 2023] McKinney, A., 2023. A Brief Overview of Parallelism Strategies in Deep Learning. Blog post link

[Huggingface, 2024] Huggingface, 2024. Model Parallelism. Documentation link

[Google, 2024] JAX Team Google, 2024. SPMD multi-device parallelism with shard_map. Notebook link

[OpenAI, 2024] OpenAI, 2024. GPT-4. Technical Report

[Google, 2024] Gemini Team Google Deepmind, 2024. Gemini. Technical Report

