How to Manage Immutable State Variables in Flax and JAX

He Zhang
2 min read · Jan 8, 2024

The Problem

Sometimes we want a constant, non-trainable weight in a machine-learning layer. In PyTorch and TensorFlow this is usually done by turning off the gradient on a parameter, but how do we achieve it in the Flax framework? A first attempt is a similar setup using self.param in Flax. However, this turns out to be costly when the initializing function is expensive, because the initializer is re-run every time the model is reconstructed.

The initializing function is called whenever we try to apply the model. Source: Author.

First Attempt: Using JIT

Okay, if the model is being reconstructed every time we call it, can we optimize it so that it isn't? JAX follows a functional-programming style, so state is tracked explicitly, and the reconstruction can be optimized away with jax.jit: after the first trace, the compiled function is cached, so the model is not reconstructed on every call.

Using JIT solves the issue of repeatedly calling initializing function. Source: Author.
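A sketch of the JIT idea (function names are illustrative). Under jax.jit, the Python body runs only while tracing; later calls with the same input shapes reuse the compiled computation, so the expensive constant is built only once:

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # how many times Python actually executes the body

def expensive_constant():
    """Stand-in for an expensive, input-independent computation."""
    global TRACE_COUNT
    TRACE_COUNT += 1
    return jnp.arange(4.0)

@jax.jit
def forward(x):
    # Under jit this line runs only during tracing; afterwards the cached,
    # compiled computation is reused and the constant is not rebuilt.
    const = expensive_constant()
    return x + const

x = jnp.zeros((4,))
forward(x)
forward(x)
forward(x)
print(TRACE_COUNT)  # → 1: traced once, cached thereafter
```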

However, this approach messes up the model's parameter scope and the reported trainable model size when we try to print it, since in JAX we cannot distinguish whether a parameter is trainable or not (the framework is not object-oriented, so there is no syntax like param.is_trainable()).
