The Problem
Sometimes we want to create an immutable state or a constant weight in a machine-learning layer. In PyTorch and TensorFlow this is usually achieved with a parameter whose gradient is turned off, but how do we achieve it in the Flax framework? A first attempt is a similar setup using self.param in Flax. However, this turns out to be expensive when the initializing function is expensive, because the model is reconstructed every time we call it.
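Below is a minimal sketch of that first attempt. The names (ConstantLayer, expensive_init) are hypothetical, not from the original code; the point is only that the constant is registered through self.param, so Flax handles it like any other trainable parameter.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


def expensive_init(key, shape):
    # Stand-in for a costly initializer (e.g. building a large lookup table).
    del key  # the constant is deterministic, so the RNG key is unused
    return jnp.ones(shape)


class ConstantLayer(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # self.param registers the constant as an ordinary (trainable) parameter.
        const = self.param('const', expensive_init, (x.shape[-1], self.features))
        return x @ const


layer = ConstantLayer(features=4)
variables = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
y = layer.apply(variables, jnp.ones((1, 8)))
```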
First Attempt: Using JIT
Okay, if the model is being reconstructed every time we call it, can we optimize it so that it isn't? JAX follows a functional style, so we have to keep track of the state ourselves, and the call can be optimized with jax.jit. With JIT, the model is traced and compiled once instead of being reconstructed on every call.
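Continuing the hypothetical ConstantLayer sketch above, one way to apply this idea is to wrap the forward pass with jax.jit, so the expensive construction is only paid at trace time and later calls reuse the compiled computation:

```python
# Wrap the forward pass; the first call traces and compiles, later calls reuse it.
fast_apply = jax.jit(layer.apply)

y1 = fast_apply(variables, jnp.ones((1, 8)))  # first call: traces and compiles
y2 = fast_apply(variables, jnp.ones((1, 8)))  # subsequent calls skip reconstruction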
However, this approach messed up the model's parameter scope and the reported trainable model size when we tried to print it, because in JAX we cannot distinguish whether a parameter is trainable or not (it is not OOP, so there is no syntax like param.is_trainable()).
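One way to see the issue, still under the assumptions of the sketch above: printing the variable tree shows the constant sitting in the 'params' collection next to real trainable weights, so a naive parameter count treats it as trainable.

```python
# The constant lives in 'params' like any trainable weight.
print(jax.tree_util.tree_map(jnp.shape, variables))
# {'params': {'const': (8, 4)}}

# A naive "trainable size" count includes the frozen constant.
trainable_size = sum(p.size for p in jax.tree_util.tree_leaves(variables['params']))
print(trainable_size)  # 32, even though 'const' should never be trained
```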