How to Write a Training Loop For Data Parallelism

He Zhang
Feb 6, 2024

This article explains my understanding of the logic behind the train.train_and_evaluate function you see when writing a training loop.

The function signature:

train_and_evaluate(config) -> train_state, train_iter_state

This function should take a configuration dictionary of your choice as its entry point. The return value should include the model state in JAX (or the model itself in another framework). The train iter state is the state of the training data input pipeline; it records where we are in the data so that training can resume later if needed.
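Here is a minimal sketch of what such a function could look like. The config keys and the helpers (create_train_state, get_train_iter, next_batch, train_step, compute_num_train_steps) are hypothetical placeholders I am assuming for illustration, not part of any specific library.

```python
import jax

def train_and_evaluate(config):
    # Model parameters + optimizer state (framework-specific).
    train_state = create_train_state(config)
    # Data pipeline plus its position, so we can checkpoint and resume.
    train_iter, train_iter_state = get_train_iter(config)

    num_steps = compute_num_train_steps(config)  # see the next section
    for step in range(num_steps):
        batch, train_iter_state = next_batch(train_iter, train_iter_state)
        train_state = train_step(train_state, batch)

    # Returning both states lets the caller checkpoint and later resume
    # from the exact same point in the data stream.
    return train_state, train_iter_state
```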

Num Training Steps

The trick is calculating the number of training steps when we are running on multiple hosts. To make that work, we fix a per-device batch size and derive the number of steps from the total data size and the number of epochs.

The data is first divided across hosts by jax.process_count(), and each host then splits its host-level shard into device-level shards across jax.local_device_count() devices. The training data per host = data_size * num_epochs // jax.process_count().

The number of training steps per host = training data per host // (per_device_batch_size * jax.local_device_count()).
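A sketch of that calculation, assuming data_size, num_epochs, and per_device_batch_size come from the config dictionary:

```python
import jax

def compute_num_train_steps(config):
    # Examples this host will see across all epochs.
    examples_per_host = (
        config["data_size"] * config["num_epochs"] // jax.process_count()
    )
    # Examples consumed per step on this host, across its local devices.
    host_batch_size = config["per_device_batch_size"] * jax.local_device_count()
    return examples_per_host // host_batch_size
```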

Learning Rate

Be careful: the learning rate is defined for the global batch size. So if we want to use a different number of devices, we need to scale the learning rate accordingly. One trick…
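As an illustration of the scaling idea (not necessarily the trick the article goes on to describe), a common approach is the linear scaling rule: scale the learning rate in proportion to the global batch size. base_lr and base_global_batch_size here are assumed reference values from the original training setup.

```python
import jax

def scaled_learning_rate(base_lr, base_global_batch_size, per_device_batch_size):
    # Global batch size grows with the total number of devices across all hosts.
    global_batch_size = per_device_batch_size * jax.device_count()
    return base_lr * global_batch_size / base_global_batch_size
```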
