This article explains my understanding of the logic behind the train.train_and_evaluate function you see when writing a training loop.
The function signature:
train_and_evaluate(config) -> train_state, train_iter_state
This function takes a configuration dictionary of your choice as its entry point. The output should include the model state in JAX (or the model object in another framework). The train iterator state records where we are in the training data so that training can be resumed later if needed.
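To make that contract concrete, here is a minimal sketch of such a function in JAX. The config keys, the toy linear model, and the use of optax as the optimizer library are my own assumptions for illustration, not the actual implementation.

import jax
import jax.numpy as jnp
import optax

def train_and_evaluate(config):
    rng = jax.random.PRNGKey(config["seed"])
    # Toy model state: a single weight vector plus the optimizer state.
    params = {"w": jax.random.normal(rng, (config["dim"],))}
    tx = optax.sgd(config["learning_rate"])
    opt_state = tx.init(params)

    def loss_fn(p, x, y):
        return jnp.mean((x @ p["w"] - y) ** 2)

    # The "train iter state" here is just a step index; in a real input
    # pipeline it would be a checkpointable data-iterator state.
    train_iter_state = 0
    for step in range(config["num_train_steps"]):
        # Stand-in for next(train_iter): a synthetic batch per step.
        step_rng = jax.random.PRNGKey(step)
        x = jax.random.normal(step_rng, (config["batch_size"], config["dim"]))
        y = jnp.zeros((config["batch_size"],))
        grads = jax.grad(loss_fn)(params, x, y)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        train_iter_state = step + 1

    train_state = {"params": params, "opt_state": opt_state}
    return train_state, train_iter_state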
Num Training Steps
The tricky part is calculating the number of training steps when we are running on multiple hosts. To make this work, we set a per-device batch size and derive the number of training steps from the total data size and the number of epochs.
The data is first divided across hosts by jax.process_count(), and then each host splits its host-level shard into device-level shards across jax.local_device_count() devices. The training data per host = data_size * num_epochs // jax.process_count(). The per-device training steps = training data per host // (per-device-batch-size * jax.local_device_count()).
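As a sketch, the step calculation above might look like this (the function and argument names, and the 4-host, 8-device topology in the example comment, are assumptions for illustration):

import jax

def compute_num_train_steps(data_size, num_epochs, per_device_batch_size):
    # Total examples seen over all epochs, split evenly across hosts.
    examples_per_host = data_size * num_epochs // jax.process_count()
    # Each host processes per_device_batch_size examples on each of its
    # local devices at every step.
    host_batch_size = per_device_batch_size * jax.local_device_count()
    return examples_per_host // host_batch_size

# Example: 1,000,000 examples, 3 epochs, per-device batch size 32.
# With 4 hosts of 8 devices each, this gives
# (3,000,000 // 4) // (32 * 8) = 750,000 // 256 = 2929 steps per device.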
Learning Rate
Be careful: the learning rate is defined for the global batch size. So if we want to use a different number of devices, we need to scale the learning rate accordingly. One trick…
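One common approach (not necessarily the trick alluded to above) is linear scaling: keep the ratio of learning rate to global batch size constant when the device count changes. A minimal sketch, with function and argument names of my own choosing:

import jax

def scaled_learning_rate(base_lr, base_global_batch_size, per_device_batch_size):
    # Global batch size under the current topology: per-device batch size
    # times the total device count across all hosts.
    global_batch_size = per_device_batch_size * jax.device_count()
    # Linear scaling: keep base_lr / base_global_batch_size constant.
    return base_lr * global_batch_size / base_global_batch_size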