How to Fine-tune a Flax or JAX Model with HuggingFace Models as Sub-layers?
Are you still freezing your HuggingFace model when fine-tuning?
If you search online for HuggingFace fine-tuning blog posts, you will see people doing all sorts of things.
One common pattern is to instantiate the HuggingFace model outside of the Flax module class, which means the HuggingFace model itself never actually gets fine-tuned.
import flax.linen as nn
import transformers

config = transformers.T5Config()
huggingface_model = transformers.FlaxT5EncoderModel(config)

class MyModel(nn.Module):
    def setup(self):
        # The HuggingFace model lives outside the module, so its parameters
        # are not registered here and will never be updated during training.
        self._huggingface_model = huggingface_model
        self._head = nn.Dense(1)

    def __call__(self, x):
        x = self._huggingface_model(x).last_hidden_state
        x = self._head(x[:, 0])
        return x
This is not the best way to do it! The HuggingFace model is not a submodule of MyModel, so its parameters never appear in MyModel's parameter tree and the optimizer will never update them.
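To see the problem, here is a quick sketch (assuming the MyModel definition above): when we initialize it, only the new head shows up in the parameter tree.

import jax
import jax.numpy as jnp

model = MyModel()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8), dtype=jnp.int32))
# Prints only the Dense head's kernel and bias; the T5 encoder's weights
# live in huggingface_model.params and are invisible to the optimizer.
print(jax.tree_util.tree_map(jnp.shape, params))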
Another common tutorial you can find is to fine-tune the HuggingFace model on its own, without any modification to its architecture.
# In these tutorials, TrainState is usually a custom subclass of
# flax.training.train_state.TrainState that also stores logits_function and
# loss_function, and adamw is a small helper wrapping optax.adamw.
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)
This is also not the best way, as it limits you to the model architectures exactly as HuggingFace defines them; you cannot add your own layers on top.
How do we compose HuggingFace models like LEGO blocks inside our own architectures and still be able to fine-tune them?
The Correct Way
It’s simple: each HuggingFace Flax model exposes a module member, which is the underlying Flax nn.Module that we can use as a sub-layer inside our own module.
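Here is a minimal sketch of the idea (FlaxT5EncoderModel and the t5-small checkpoint are used just for concreteness; the class and attribute names are illustrative):

import flax.linen as nn
import jax.numpy as jnp
import transformers

hf_model = transformers.FlaxT5EncoderModel.from_pretrained("t5-small")

class MyModel(nn.Module):
    def setup(self):
        # hf_model.module is the underlying Flax nn.Module, so its parameters
        # become part of MyModel's parameter tree and get fine-tuned together
        # with the new head.
        self._encoder = hf_model.module
        self._head = nn.Dense(1)

    def __call__(self, x):
        x = self._encoder(x, attention_mask=jnp.ones_like(x)).last_hidden_state
        x = self._head(x[:, 0])
        return x

Note that initializing MyModel this way produces freshly initialized encoder weights; to fine-tune from the pretrained checkpoint, you would copy hf_model.params into the encoder's subtree of MyModel's parameters before creating the train state.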