Part 4.1: Tensor Parallelism

Filled notebook: view on GitHub or open in Google Colab.

Author: Phillip Lippe

In this tutorial, we will discuss tensor parallelism, another important parallelism strategy for training large-scale deep learning models. Similar to pipeline parallelism, tensor parallelism is a model parallelism strategy, which means that it focuses on parallelizing the model itself, rather than the data. The key difference between pipeline and tensor parallelism is how they split the model over devices. In pipeline parallelism, the model is split over devices along the sequence of layers (i.e. vertically), while in tensor parallelism, the model is split over devices along the feature dimensions (i.e. horizontally). Each device will then process a different subset of features, and the model’s forward and backward passes will be split over devices accordingly. A short overview of the parallelism strategies is shown below.

[Figure: overview of parallelism strategies]

Tensor parallelism can be applied on a per-module/per-layer basis. This gives more flexibility in how to split the model over devices than pipeline parallelism, and can even handle situations where a single layer is too big to fit on a single device. Furthermore, tensor parallelism does not suffer from the pipeline bubble problem, as all devices can work on the same batch of data at the same time. The key behind making tensor parallelism efficient will be, again, to overlap computation with communication, and to minimize the amount of communication required.

Still, tensor parallelism relies on frequent communication between devices. It therefore requires devices with high-speed interconnects, such as TPUs or GPUs with NVLink, and is often restricted to devices within a single node. For example, Gemini v1 was trained with model parallelism within a node (TPU superpod), but with only data parallelism across nodes.

In this tutorial, we will discuss the principles of tensor parallelism, and how to implement it in JAX. We will first start with an implementation on a simple MLP model. In Part 4.2, we discuss techniques from models like the ViT-22b to maximize efficiency of tensor parallelism with compute-communication overlaps. Finally, in Part 4.3, we will discuss how to apply tensor parallelism to the transformer model specifically, and how to combine tensor parallelism with fully-sharded data parallelism.

Prerequisites

First, let’s start with setting up the basic environment and utility functions we have seen from previous notebooks. We download the python scripts of the previous notebooks below. This is only needed when running on Google Colab, and local execution will skip this step automatically.

[1]:
import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download.
python_files = ["single_gpu.py", "data_parallel.py", "pipeline_parallel.py", "utils.py"]
# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

As before, we simulate 8 devices on CPU to demonstrate the parallelism without the need for multiple GPUs or TPUs. If you are running on your local machine and have multiple GPUs available, you can comment out the lines below.

[2]:
from utils import simulate_CPU_devices

simulate_CPU_devices()

We now import our standard libraries.

[3]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Literal, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict

PyTree = Any
Parameter = jax.Array | nn.Partitioned
Metrics = Dict[str, Tuple[jax.Array, ...]]

We also import the utility functions from the previous notebooks. Our notebook will rely on the ModelParallelismWrapper from the pipeline parallelism notebook. If you are not familiar with it, it is recommended to look at the implementation of this module before continuing.

[4]:
from data_parallel import fold_rng_over_axis, sync_gradients
from pipeline_parallel import ModelParallelismWrapper
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics

Tensor Parallelism for Linear Layers

The key design principle behind tensor parallelism is to split the model over devices along the feature dimensions. For instance, consider a Transformer model with a hidden size of 1024, and we want to split the model over 4 devices. We would then split the hidden size over the devices, such that device 0 will process features 0-255, device 1 will process features 256-511, and so on. However, as we know from basic deep learning principles, the hidden dimensions are rarely independently processed. Thus, we need to design the model layers such that they communicate features or outputs between devices efficiently, whenever it is needed.

As the most basic neural network operation, let's consider a matrix multiplication as we would do it in an MLP. We can write it as \(Ax=y\), with \(A\in\mathbb{R}^{d_y\times d_x}\) being the weight matrix, \(x\in\mathbb{R}^{d_x\times B}\) the input (batch last for simplicity), and \(y\in\mathbb{R}^{d_y\times B}\) the output. In tensor parallelism, each device will carry a subset of the input dimensions, e.g. \(x\) is split into \(x_0, x_1, x_2, x_3\) across devices. The goal is to end up with the same output \(y\) as if we had computed it on a single device, but again partitioned across devices (\(y_0, y_1, y_2, y_3\)). This is visualized below.

[Figure: splitting the matrix multiplication \(Ax=y\) across devices]

The question is now how to split \(A\) such that we can compute \(y\) in a distributed manner. The two main strategies we can follow are communicating the input (gather) or the output (scatter).

In the gather strategy, we communicate the input \(x\) to all devices, such that each device has the full \(x\) (this communication type is called all_gather). Then, we can compute the output \(y_i\) on each device \(i\) independently by: \(y_i = \sum_{j} A_{i,j} x_j\).

In the scatter strategy, each device \(i\) first computes the contribution of its input slice \(x_i\) to every output block independently: \(y^{(i)}_j=A_{j,i}x_{i}\). Afterwards, we communicate the partial results across devices and sum them, such that each device ends up with its output slice: \(y_i = \sum_{j} y^{(j)}_i\). This communication type is called psum_scatter (a reduce-scatter).

In terms of the weight matrix \(A\), the two strategies differ in that the gather strategy splits the rows of \(A\) across devices, while the scatter strategy splits the columns of \(A\) across devices. We visualize the two strategies below (for simplicity, the communication is not explicitly visualized).

[Figure: gather and scatter strategies for splitting \(A\) across devices]

Which of the two strategies is more efficient depends on the sizes of the input and output dimensions. In general, we want to communicate as little data as possible: the gather strategy communicates the input, while the scatter strategy communicates the (partial) outputs. Thus, the gather strategy is more efficient if the input dimension is much smaller than the output dimension, and the scatter strategy if the output dimension is much smaller than the input dimension. Since the dimensions will be different for each layer, we will need to decide on a per-layer basis which strategy to use. For example, in an MLP block of a Transformer where we expand the hidden dimension by 4x, we will want to use the gather strategy for the first linear layer, and the scatter strategy for the second linear layer. This way, we avoid communicating the large hidden dimensionality.
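
To make the two strategies concrete, we can reproduce them with plain block arithmetic on a single device: we split \(A\) and \(x\) into four blocks and verify that both strategies recover \(Ax\). This is a small standalone check with arbitrary dimensions; no actual device communication is involved.

import jax
import jax.numpy as jnp

num_devices, d_x, d_y, B = 4, 8, 12, 2
rng_A, rng_x = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(rng_A, (d_y, d_x))
x = jax.random.normal(rng_x, (d_x, B))
y_ref = A @ x

# Gather strategy: device i holds the row block A_{i,:} and, after an all_gather of the
# inputs, the full x. Each device computes its output slice y_i directly.
A_rows = jnp.split(A, num_devices, axis=0)
y_gather = jnp.concatenate([A_i @ x for A_i in A_rows], axis=0)

# Scatter strategy: device i holds the column block A_{:,i} and only its input slice x_i.
# Each device computes a partial output over the full output dimension, and a psum_scatter
# would sum these partials across devices (here emulated by a plain sum).
A_cols = jnp.split(A, num_devices, axis=1)
x_parts = jnp.split(x, num_devices, axis=0)
y_scatter = sum(A_i @ x_i for A_i, x_i in zip(A_cols, x_parts))

print(jnp.allclose(y_ref, y_gather, atol=1e-5), jnp.allclose(y_ref, y_scatter, atol=1e-5))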

Let's now implement the two strategies in JAX. In the gather strategy, each device will hold \(A_{i,:}\in\mathbb{R}^{d_y/4\times d_x}\) of the weight matrix, and in the scatter strategy, each device will hold \(A_{:,i}\in\mathbb{R}^{d_y\times d_x/4}\) of the weight matrix. This raises a small difficulty during initialization. Many initialization strategies depend on the shape of the full weight matrix, and we need to adjust them to the shape of the split weight matrix. As a simple trick, we will implement a wrapper around the init function that scales the values by a specified constant. We then leave it up to the user to adjust the constant such that the initialization is appropriate for the split weight matrix. For instance, if we use a fan-in initialization (e.g. He initialization), we would scale the initialization by \(\sqrt{1/\text{num}\_\text{devices}}\) for the scatter strategy to adjust for the \(1/\text{num}\_\text{devices}\) smaller input dimension. For the gather strategy, we would not need to scale the initialization, since all devices process the full input dimension. For more details on network initialization, see our initialization tutorial. As an alternative, we could implement our own initializer functions that directly take the split weight matrix dimensions into account (which may be tedious if we want to support all initializers), or initialize the full weight matrix on each device and then split it. However, the latter would be less efficient and could even fail if the weight matrix is too large to fit on a single device.
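
To see why this factor works, consider a fan-in initializer that draws weights with variance \(c/d_{\text{in}}\), where \(d_{\text{in}}\) is the input dimension of the full weight matrix and \(c\) a constant (e.g. \(c=1\) for LeCun, \(c=2\) for He initialization). Initializing the scatter shard with its local shape uses a fan-in of \(d_{\text{in}}/\text{num}\_\text{devices}\) instead:

\[
\operatorname{Var}\left[A^{\text{local}}\right] = \frac{c}{d_{\text{in}}/\text{num}\_\text{devices}} = \text{num}\_\text{devices}\cdot\frac{c}{d_{\text{in}}},
\]

i.e. \(\text{num}\_\text{devices}\) times too large. Scaling the drawn values by \(\sqrt{1/\text{num}\_\text{devices}}\) multiplies the variance by \(1/\text{num}\_\text{devices}\) and recovers the target variance \(c/d_{\text{in}}\) of the full matrix.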

[5]:
def scale_init(init_fn: Callable, scale_factor: float = 1.0):
    """Scales the output of the given init function by the given factor.

    Args:
        init_fn: The init function to scale.
        scale_factor: The factor to scale the output of the init function by.

    Returns:
        A new init function that scales the output of the given init function by the given factor.
    """

    def _init_fn(rng, *args, **kwargs):
        return scale_factor * init_fn(rng, *args, **kwargs)

    return _init_fn
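
As a quick sanity check of how we will use this wrapper for the scatter strategy (a small illustrative example with arbitrary dimensions, not part of the training code): with 4-way tensor parallelism, the per-device shard sees a 4x smaller fan-in, and scaling a lecun_normal initializer by \(4^{-0.5}\) recovers the standard deviation of the full-matrix initialization.

import jax
import flax.linen as nn

full_init = nn.initializers.lecun_normal()  # std ~ sqrt(1 / fan_in)
local_init = scale_init(nn.initializers.lecun_normal(), scale_factor=4**-0.5)

rng = jax.random.PRNGKey(0)
w_full = full_init(rng, (2048, 256))   # full weight matrix, fan_in = 2048
w_local = local_init(rng, (512, 256))  # per-device shard, fan_in = 512

print(w_full.std(), w_local.std())  # both approximately sqrt(1 / 2048) ~ 0.022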

We implement tensor parallelism for the linear layer below in a wrapper module TPDense. It takes as input a constructor dense_fn to create the linear layer. The TPDense module will then split the weight matrix over the devices, and implement the gather and scatter strategies for the forward and backward passes. For some layers, we may need to implement custom communication. For instance, the very first layer of the model may already have its input gathered over devices, since we can prefetch the batch from the host to all devices. Similarly, in the last layer of the model, we may not want to scatter the output, but rather gather it so that the loss can be computed on the full output. We will implement these custom communications in the full model later, and for now support them via the keyword skip_communication.

[6]:
class TPDense(nn.Module):
    """Dense layer with Tensor Parallelism support.

    This layer can be used to perform a dense layer with Tensor Parallelism support.

    Attributes:
        dense_fn: Constructor function of the dense layer to use. Needs to support the keyword argument `kernel_init`.
        model_axis_name: The name of the model axis.
        tp_mode: The Tensor Parallelism mode to use. Can be "scatter", "gather", or "none".
        skip_communication: Whether to skip communication in the Tensor Parallelism strategy. Useful for layers with custom communication or where input has been already gathered beforehand.
        kernel_init: The initializer to use for the kernel of the dense layer.
        kernel_init_adjustment: The adjustment factor to use for the kernel initializer.
        dense_name: The name of the dense layer module.
    """

    dense_fn: Any
    model_axis_name: str
    tp_mode: Literal["scatter", "gather", "none"] = "none"
    skip_communication: bool = False
    kernel_init: Callable = nn.initializers.lecun_normal()
    kernel_init_adjustment: float = 1.0
    dense_name: str = "module"

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        tp_size = jax.lax.psum(1, self.model_axis_name)
        tp_mode = self.tp_mode if tp_size > 1 else "none"
        # Wrap the dense layer in a ModelParallelismWrapper to shard the parameters.
        dense_fn = functools.partial(
            ModelParallelismWrapper,
            model_axis_name=self.model_axis_name,
            module_fn=functools.partial(
                self.dense_fn,
                kernel_init=scale_init(self.kernel_init, self.kernel_init_adjustment),
            ),
            name=self.dense_name,
        )

        if tp_mode == "none":
            # Vanilla dense layer.
            x = self.dense_fn(kernel_init=self.kernel_init)(x)
        elif tp_mode == "gather":
            # Gather strategy: communicate all the inputs to all the devices, then perform the dense layer.
            if not self.skip_communication:
                x = jax.lax.all_gather(x, self.model_axis_name, axis=-1, tiled=True)
            x = dense_fn()(x)
        elif tp_mode == "scatter":
            # Scatter strategy: perform the dense layer on each device, then communicate the outputs to all the devices.
            x = dense_fn()(x)
            if not self.skip_communication:
                x = jax.lax.psum_scatter(
                    x, axis_name=self.model_axis_name, scatter_dimension=x.ndim - 1, tiled=True
                )
        else:
            raise ValueError(f"Unknown Tensor Parallel mode: {tp_mode}")
        return x

Note that one small difference we are skipping over for now is the bias term in the scatter strategy. In the current implementation, each device holds a separate bias term, and these bias terms are summed across devices in the forward pass. This effectively gives the bias a \(\text{num}\_\text{devices}\)-times (here four times) higher learning rate, which may be undesirable. For simplicity, we will ignore this for now since this will not be our final module, but in later modules, we show how this is addressed.
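
One possible way to address it (a rough sketch with hypothetical names, not necessarily the solution used in the later notebooks) is to disable the bias inside the sharded dense layer and add each device's shard of a single bias only after the psum_scatter, so that every output feature receives exactly one bias term:

import jax

def scatter_dense_single_bias(x_local, w_local, b_local, axis_name="model"):
    # Sketch, intended to run inside a shard_map over the model axis. w_local is this
    # device's column block of the weight matrix, b_local its slice of one shared bias.
    y_partial = x_local @ w_local  # partial outputs over the full output dimension
    y_local = jax.lax.psum_scatter(
        y_partial, axis_name=axis_name, scatter_dimension=y_partial.ndim - 1, tiled=True
    )
    return y_local + b_local  # the bias is added exactly once per output feature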

MLP Block

As an example network, we will implement an MLP block of the same form as used in Transformers. It consists of a normalization layer, a linear layer scaling up the hidden dimensionality, a non-linearity, and a linear layer scaling down the hidden dimensionality again. As discussed before, we will use the gather strategy for the first linear layer, and the scatter strategy for the second linear layer. The computation graph per device is visualized below.

7199537131b04f44b3a832da82cd09f3

Here, \(h_0\) are the intermediate features of the MLP on device 0 (they can have a different dimension than \(x\)), and \(y^{(0)}\) the outputs calculated from \(h_0\) alone. The gather and scatter operations are performed at the two ends of the MLP, such that no communication needs to be performed within the MLP block, increasing efficiency.

We start with implementing the input layer, which consists of the normalization and the first linear layer. Afterwards, we want to wrap this module in a TPDense module with the gather strategy. As an example, we use the RMSNorm layer which is used in several recent large models, including ViT-22b and Gemma. Compared to LayerNorm, it does not center the input and does not apply a bias parameter, leading to a small speed gain without degrading model performance. More details are given in the paper. Further, for simplicity and following common practice, we do not apply Dropout within the MLP block.
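
As a reminder of what RMSNorm computes, a minimal sketch of the standard formulation is shown below (Flax's nn.RMSNorm additionally handles dtypes, axis options, and the exact epsilon placement): the input is rescaled by its root mean square over the feature dimension and a learned per-feature scale, without mean subtraction or bias.

import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
    # Normalize by the root mean square over the last (feature) dimension.
    rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + eps)
    return x / rms * scale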

[7]:
class MLPBlockInput(nn.Module):
    config: ConfigDict
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    use_bias: bool = True
    use_norm: bool = True

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        if self.use_norm:
            x = nn.RMSNorm(dtype=self.config.dtype, name="pre_norm")(x)
        x = nn.Dense(
            features=self.features,
            kernel_init=self.kernel_init,
            use_bias=self.use_bias,
            dtype=self.config.dtype,
            name="dense",
        )(x)
        return x

The output layer consists of the non-linearity and the second linear layer. We will wrap this module in a TPDense module with the scatter strategy. As an example, we use the SiLU non-linearity. Whether we apply the non-linearity in the output layer or the input layer is a design choice in this case, since we use the gather strategy for the input layer. However, had we applied the scatter strategy for the input layer, we could only apply the non-linearity in the output layer, since the scatter would otherwise sum the outputs of the activation function instead of the raw pre-activations.

[8]:
class MLPBlockOutput(nn.Module):
    config: ConfigDict
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    use_bias: bool = True

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        x = nn.silu(x)
        x = nn.Dense(
            features=self.features,
            kernel_init=self.kernel_init,
            use_bias=self.use_bias,
            dtype=self.config.dtype,
            name="dense",
        )(x)
        return x

We can now combine the two modules in a single MLPBlock module. For the parallelism strategies to work correctly, we need to adjust the feature counts accordingly: the input layer outputs \(1/\text{num}\_\text{devices}\) of the hidden features on each device, while the output layer maps these back to the full feature dimension before the scatter. As mentioned earlier, we also adjust the initialization of the scatter layer by scaling it by \(\sqrt{1/\text{num}\_\text{devices}}\), since we use a fan-in initialization strategy.

[9]:
class TPMLPBlock(nn.Module):
    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        tp_size = jax.lax.psum(1, self.config.model_axis_name)
        input_features = x.shape[-1]
        # Input layer
        x = TPDense(
            dense_fn=functools.partial(
                MLPBlockInput,
                config=self.config,
                features=self.config.hidden_size * self.config.mlp_expansion // tp_size,
            ),
            model_axis_name=self.config.model_axis_name,
            tp_mode="gather",
            name="input",
        )(x)
        # Output layer
        x = TPDense(
            dense_fn=functools.partial(
                MLPBlockOutput,
                config=self.config,
                features=input_features * tp_size,
            ),
            model_axis_name=self.config.model_axis_name,
            tp_mode="scatter",
            kernel_init_adjustment=tp_size**-0.5,  # fan-in with tp_size fewer inputs.
            name="output",
        )(x)
        return x

MLP Classifier

Our example model will consist of a stack of MLP blocks. For this, we write below a simple wrapper around the MLPBlock to stack multiple blocks. For efficient compilation, we use nn.scan to apply the same MLP block structure in all layers. The carry between the modules is the sharded features over the model axis.

[10]:
class TPMLPLayers(nn.Module):
    config: ConfigDict
    train: bool
    block_class: Callable[..., nn.Module] = TPMLPBlock

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        module = self.block_class(config=self.config, train=self.train, name="block")
        x, _ = nn.scan(
            lambda module, carry, _: (module(carry) + carry, None),
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            length=self.config.num_layers,
            metadata_params={
                "partition_name": None
            },  # We do not need to partition the parameters over the layer axis.
        )(module, x, ())
        return x

Finally, we combine the MLP blocks with an input and output layer. We expect the input to the model to be duplicated over model devices, such that it does not need to be gathered anymore. This is likely the best setup for the input processing anyway, since the batch can already be prefetched to all devices, and we may not be able to split the input over model devices equally (e.g. text inputs are single integers per token, so there is no feature dimension to split). If working with a mesh where the model axis goes across processes, we may instead want to split the input over model devices on the batch dimension and gather it before applying the model, as sketched below. This ensures all model devices start with the same input.
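
For that alternative setup, the extra communication before the first layer could look as follows (a hypothetical helper, not used in this notebook):

import jax

def gather_input_over_model_axis(x, model_axis_name):
    # Gather a batch-sharded input over the model axis (inside shard_map), so that all
    # model devices see the same inputs before the first, gather-strategy layer.
    return jax.lax.all_gather(x, model_axis_name, axis=0, tiled=True)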

The output layer will be a linear layer with the number of classes as output dimension. We will wrap this layer in a TPDense module with the scatter strategy, but we will not scatter the output. Instead, to compute the loss, a device needs to have the full output features. Hence, we apply a jax.lax.psum to sum the final output over devices. Note that this gives all model devices the same tensor, and thus the same loss. We may then want to calculate the loss only on a single device (masking the others), and broadcast it back to all devices via the psum operation. For models with large output sizes, this might be inefficient since a single device needs to be able to hold the entire output. For simplicity, we will ignore this for now, but address it in the transformer model later. Finally, as usual, we convert the output to float32 to avoid numerical issues in the loss computation.

[11]:
class TPClassifier(nn.Module):
    config: ConfigDict
    block_class: Callable[..., nn.Module] = TPMLPBlock

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        tp_size = jax.lax.psum(1, self.config.model_axis_name)
        # Input layer
        x = TPDense(
            dense_fn=functools.partial(
                nn.Dense,
                features=self.config.hidden_size // tp_size,
                dtype=self.config.dtype,
            ),
            model_axis_name=self.config.model_axis_name,
            tp_mode="gather",
            skip_communication=True,  # Input already gathered.
            name="input_layer",
        )(x)
        # Backbone MLP blocks
        x = TPMLPLayers(config=self.config, train=train, name="mlp", block_class=self.block_class)(
            x
        )
        # Output layer
        x = TPDense(
            dense_fn=functools.partial(
                nn.Dense,
                features=self.config.num_classes,
                dtype=self.config.dtype,
            ),
            model_axis_name=self.config.model_axis_name,
            tp_mode="scatter",
            skip_communication=True,  # Manual communication.
            name="output_layer",
            kernel_init_adjustment=tp_size**-0.5,  # fan-in with tp_size fewer inputs.
        )(x)
        x = jax.lax.psum(x, axis_name=self.config.model_axis_name)
        x = x.astype(jnp.float32)
        return x

Initialization

With the model implemented, we can now initialize the model. We start with the config definition, which is similar to previous notebooks. We parallelize the model over 4 devices, and for simplicity, keep the MLP expansion factor at 1. Feel free to experiment with different configurations.

[12]:
data_config = ConfigDict(
    dict(
        batch_size=128,
        num_classes=10,
        input_size=784,
    )
)
model_config = ConfigDict(
    dict(
        hidden_size=512,
        dropout_rate=0.1,
        mlp_expansion=1,
        num_layers=3,
        dtype=jnp.bfloat16,
        num_classes=data_config.num_classes,
        data_axis_name="data",
        model_axis_name="model",
        model_axis_size=4,
    )
)
optimizer_config = ConfigDict(
    dict(
        learning_rate=1e-3,
        num_minibatches=1,
    )
)
config = ConfigDict(
    dict(
        model=model_config,
        optimizer=optimizer_config,
        data=data_config,
        data_axis_name=model_config.data_axis_name,
        model_axis_name=model_config.model_axis_name,
        model_axis_size=model_config.model_axis_size,
        seed=42,
    )
)

The rest of the initialization is identical to the previous notebook on pipeline parallelism. We first create our mesh over data and model.

[13]:
device_array = np.array(jax.devices()).reshape(-1, config.model_axis_size)
mesh = Mesh(device_array, (config.data_axis_name, config.model_axis_name))

We then create the model object and the optimizer. We stick with simple AdamW in this example, but feel free to change the optimizer setup.

[14]:
model_tp = TPClassifier(config=config.model)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)

For simplicity, we will train the model on a simple random data classification task. This is mainly to demonstrate the tensor parallelism, and not to achieve state-of-the-art results. In practice, one would instead create a dataset and dataloader at this point, and set up the data prefetching.

[15]:
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
batch = Batch(
    inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)),
    labels=jax.random.randint(
        data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
    ),
)

The initialization function follows the same principles as in the previous notebook, creating the parameters via model.init and the optimizer parameters in the TrainState.create.

[16]:
def init_tp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    params = variables.pop("params")
    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
    return state

Before we can run the full initialization, we need to identify the partitioning of the parameters. Since we annotated the partitioning of all parameters via nn.Partitioned in the model, we can obtain it by calling jax.eval_shape on the init function. This returns the state shapes, including the nn.Partitioned parameter leaves. From those, we can read out the partitioning using nn.get_partition_spec. For this initial call, we can leave the out_specs of the shard_map empty (i.e. P()), since we do not create the actual parameters during shape evaluation.

[17]:
init_tp_fn = shard_map(
    functools.partial(init_tp, model=model_tp),
    mesh,
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,
)
state_tp_shapes = jax.eval_shape(init_tp_fn, model_init_rng, batch.inputs)
state_tp_specs = nn.get_partition_spec(state_tp_shapes)

Let’s investigate the partitioning of the parameters.

[18]:
pprint(state_tp_specs.params)
{'input_layer': {'module': {'sharded': {'bias': PartitionSpec('model', None),
                                        'kernel': PartitionSpec('model', None, None)}}},
 'mlp': {'block': {'input': {'module': {'sharded': {'dense': {'bias': PartitionSpec(None, 'model', None),
                                                              'kernel': PartitionSpec(None, 'model', None, None)},
                                                    'pre_norm': {'scale': PartitionSpec(None, 'model', None)}}}},
                   'output': {'module': {'sharded': {'dense': {'bias': PartitionSpec(None, 'model', None),
                                                               'kernel': PartitionSpec(None, 'model', None, None)}}}}}},
 'output_layer': {'module': {'sharded': {'bias': PartitionSpec('model', None),
                                         'kernel': PartitionSpec('model', None, None)}}}}

All parameters in the model are partitioned over the model axis. For the input and output layer, this is over the first dimension, while for the MLP blocks, it is over the second dimension. This is because the first dimension of the MLP parameters is the layer dimension introduced by the scan. This also demonstrates how our implementation works well under function transformations like scan, vmap, etc. Since we do not apply FSDP for now, the parameters are not partitioned over the data axis. The several sub-keys in the parameter PyTree are due to the stacking and wrapping of the modules (e.g. sharded introduced by ModelParallelismWrapper, module introduced by TPDense). Alternatively, some of these wrappers could be rewritten into functions to avoid the sub-keys.

We can now continue with the initialization:

[19]:
init_tp_fn = jax.jit(
    shard_map(
        functools.partial(init_tp, model=model_tp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_tp_specs,
        check_rep=False,
    ),
)
state_tp = init_tp_fn(model_init_rng, batch.inputs)

We inspect the shapes of the parameters below.

[20]:
print("TP Parameters - Input Layer")
pprint(jax.tree_map(lambda x: x.shape, state_tp.params["input_layer"]["module"]["sharded"]))
TP Parameters - Input Layer
{'bias': Partitioned(value=(4, 128), names=('model', None), mesh=None),
 'kernel': Partitioned(value=(4, 784, 128),
                       names=('model', None, None),
                       mesh=None)}

The input layer uses a gather strategy, such that its input size is the full feature size (784), but its output is split over model devices (\(512 / 4 = 128\)).

[21]:
print("TP Parameters - MLP Layers Input")
pprint(
    jax.tree_map(lambda x: x.shape, state_tp.params["mlp"]["block"]["input"]["module"]["sharded"])
)
print()
print("TP Parameters - MLP Layers Output")
pprint(
    jax.tree_map(lambda x: x.shape, state_tp.params["mlp"]["block"]["output"]["module"]["sharded"])
)
TP Parameters - MLP Layers Input
{'dense': {'bias': Partitioned(value=(3, 4, 128),
                               names=(None, 'model', None),
                               mesh=None),
           'kernel': Partitioned(value=(3, 4, 512, 128),
                                 names=(None, 'model', None, None),
                                 mesh=None)},
 'pre_norm': {'scale': Partitioned(value=(3, 4, 512),
                                   names=(None, 'model', None),
                                   mesh=None)}}

TP Parameters - MLP Layers Output
{'dense': {'bias': Partitioned(value=(3, 4, 512),
                               names=(None, 'model', None),
                               mesh=None),
           'kernel': Partitioned(value=(3, 4, 128, 512),
                                 names=(None, 'model', None, None),
                                 mesh=None)}}

The MLP input layer uses a gather strategy, such that it also has the full feature size as input, but its output is split over model devices (\(512 / 4 = 128\)). Note that the norm layer has different scale parameters on each device. This is usually not a problem, since the norm layer is typically followed by a linear layer, which allows for scaling of the weights. Still, it is a difference from the single-device case that is important to keep in mind, and the scale parameters could be shared across devices if needed.

The MLP output layer follows the scatter pattern, such that its input is split over model devices (\(512 / 4 = 128\)), but its output is the full feature size.

[22]:
print("TP Parameters - Output Layer")
pprint(jax.tree_map(lambda x: x.shape, state_tp.params["output_layer"]["module"]["sharded"]))
TP Parameters - Output Layer
{'bias': Partitioned(value=(4, 10), names=('model', None), mesh=None),
 'kernel': Partitioned(value=(4, 128, 10),
                       names=('model', None, None),
                       mesh=None)}

Finally, the output layer follows the scatter pattern, such that its input is split over model devices (\(512 / 4 = 128\)), but its output is the full number of classes. Note that whether we manually implement the communication or use the TPDense communication does not have an impact on the feature size.

Another aspect to check is whether the initialization across devices works as expected. Since each device holds a different part of the weight matrix, we expect them to be initialized differently. We can check this by inspecting the parameters.

[23]:
state_tp.params["mlp"]["block"]["input"]["module"]["sharded"]["dense"]["kernel"].value[:, :, 0, 0]
[23]:
Array([[-0.06087485, -0.04099965,  0.04802493, -0.00385336],
       [ 0.0586801 , -0.01241772, -0.00626128,  0.00607279],
       [-0.05654007,  0.02550504, -0.02855512, -0.08177456]],      dtype=float32)

The above cell prints a single kernel entry of the MLP input layer across the layer axis (rows) and model devices (columns). We can see that the parameters are indeed initialized differently across devices, and thus we can continue to train the model.

Training with Tensor Parallelism

The training loop is identical to the examples in the previous notebooks. The loss function is a simple cross-entropy loss, where we mask out the loss and accuracy on all model devices except the first one, since all model devices hold identical logits after the final psum.

[24]:
def loss_fn_tp(
    params: PyTree,
    apply_fn: Any,
    batch: Batch,
    rng: jax.Array,
) -> Tuple[jax.Array, Dict[str, Any]]:
    # Since dropout masks vary across the batch dimension, we want each device to generate a
    # different mask. We can achieve this by folding the rng over the data axis, so that each
    # device gets a different rng and thus mask.
    dropout_rng = fold_rng_over_axis(rng, (config.data_axis_name, config.model_axis_name))
    # Remaining computation is the same as before for single device.
    logits = apply_fn(
        {"params": params},
        batch.inputs,
        train=True,
        rngs={"dropout": dropout_rng},
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = np.prod(batch.labels.shape)
    # Mask out loss and accuracy for model devices except first one.
    model_idx = jax.lax.axis_index(config.model_axis_name)
    loss = jnp.where(model_idx != 0, 0.0, loss)
    correct_pred = jnp.where(model_idx != 0, False, correct_pred)
    batch_size = jnp.where(model_idx != 0, 0, batch_size)
    # Collect metrics and return loss.
    step_metrics = {
        "loss": (loss.sum(), batch_size),
        "accuracy": (correct_pred.sum(), batch_size),
    }
    loss = loss.mean()
    return loss, step_metrics

In the training step, we want to support 2D parallelism with (fully-sharded) data parallelism and tensor parallelism. Thus, after having determined the gradients per device, we need to sync them over the (now 2D) data and model axes accordingly. For this, we can reuse the sync_gradients function from our fully-sharded data parallelism implementation.

We summarize everything in the training step below. It is identical to the fully-sharded data parallelism training step up to the 2D gradient synchronization, after which we apply the optimizer update as usual.

[25]:
def train_step_tp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    loss_fn: Callable = loss_fn_tp,
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name, config.model_axis_name))
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_map(
            lambda x: jax.lax.psum(x, axis_name=(config.data_axis_name, config.model_axis_name)),
            step_metrics,
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

With the training loop implemented, we can now train the model. We will train the model on a simple random data classification task, and expect the model to learn to classify the data with high accuracy. We will use a small batch size to run the model easily on a CPU-only system.

[26]:
train_step_tp_fn = jax.jit(
    shard_map(
        train_step_tp,
        mesh,
        in_specs=(state_tp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_tp_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
state_shapes, metric_shapes = jax.eval_shape(
    train_step_tp_fn,
    state_tp,
    None,
    batch,
)
metrics_tp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_tp, metrics_tp = train_step_tp_fn(state_tp, metrics_tp, batch)

We run the model for 15 steps and print the final loss and accuracy.

[27]:
for _ in range(15):
    state_tp, metrics_tp = train_step_tp_fn(state_tp, metrics_tp, batch)
final_metrics_tp = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_tp, final_metrics_tp = train_step_tp_fn(state_tp, final_metrics_tp, batch)
print_metrics(final_metrics_tp, title="Final Metrics - Tensor Parallelism")
 Final Metrics - Tensor Parallelism
accuracy: 1.000000
loss: 0.000030

As we expected, the model is able to learn the task with high accuracy. We can now continue to the next part, where we discuss a more efficient implementation exploiting the compute and communication overlap.

Intermediate Summary

In this part, we discussed the principles of tensor parallelism, and how to implement it in JAX. We implemented a simple MLP model with tensor parallelism, and trained it on a simple random data classification task. We also discussed the sharding of the parameters. In the next part, we will discuss how to maximize the efficiency of tensor parallelism with compute-communication overlaps.

References and Resources

[Shoeybi et al., 2019] Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J. and Catanzaro, B., 2019. Megatron-lm: Training multi-billion parameter language models using model parallelism. arXiv preprint arXiv:1909.08053. Paper link

[Wang and Komatsuzaki, 2021] Wang, B., and Komatsuzaki, A., 2021. Mesh transformer jax. GitHub link

[Xu et al., 2021] Xu, Y., Lee, H., Chen, D., Hechtman, B., Huang, Y., Joshi, R., Krikun, M., Lepikhin, D., Ly, A., Maggioni, M. and Pang, R., 2021. GSPMD: general and scalable parallelization for ML computation graphs. arXiv preprint arXiv:2105.04663. Paper link

[Dehghani et al., 2022] Dehghani, M., Gritsenko, A., Arnab, A., Minderer, M. and Tay, Y., 2022. Scenic: A JAX library for computer vision research and beyond. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 21393-21398). Paper link

[Yoo et al., 2022] Yoo, J., Perlin, K., Kamalakara, S.R. and Araújo, J.G., 2022. Scalable training of language models using JAX pjit and TPUv4. arXiv preprint arXiv:2204.06514. Paper link

[Chowdhery et al., 2023] Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung, H.W., Sutton, C., Gehrmann, S., Schuh, P., et al., 2023. Palm: Scaling language modeling with pathways. Journal of Machine Learning Research, 24(240), pp.1-113. Paper link

[Anil et al., 2023] Anil, R., Dai, A.M., Firat, O., Johnson, M., Lepikhin, D., Passos, A., Shakeri, S., Taropa, E., Bailey, P., Chen, Z. and Chu, E., 2023. Palm 2 technical report. arXiv preprint arXiv:2305.10403. Paper link

[Dehghani et al., 2023] Dehghani, M., Djolonga, J., Mustafa, B., Padlewski, P., Heek, J., Gilmer, J., Steiner, A.P., Caron, M., Geirhos, R., Alabdulmohsin, I., Jenatton, R., et al., 2023. Scaling vision transformers to 22 billion parameters. In International Conference on Machine Learning (pp. 7480-7512). PMLR. Paper link

[McKinney, 2023] McKinney, A., 2023. A Brief Overview of Parallelism Strategies in Deep Learning. Blog post link

[Huggingface, 2024] Huggingface, 2024. Model Parallelism. Documentation link

[Google, 2024] JAX Team Google, 2024. SPMD multi-device parallelism with shard_map. Notebook link

[OpenAI, 2024] OpenAI, 2024. GPT-4. Technical Report

[Google, 2024] Gemini Team Google Deepmind, 2024. Gemini. Technical Report


Star our repository: If you found this tutorial helpful, consider ⭐-ing our repository.
Ask questions: For any questions, typos, or bugs that you found, please raise an issue on GitHub.