I have set up a snippet on Colab here, with:
jax.__version__ # 0.4.33 9Feb2025
orbax.checkpoint.__version__ # 0.6.4 9Feb2025
It is quite difficult to follow the Flax/Orbax changes to saving and restoring a (simple) model, even when following the "latest" documentation of these two packages.
I have managed to cook something up, but I was wondering whether I'm doing the right thing when using 8 TPUs on Colab. For instance, it seems that one can only save a single instance of the model among the 8 existing replicas, i.e. the use of flax.jax_utils.unreplicate seems necessary:
ckpt = {'model': flax.jax_utils.unreplicate(model_state)}
At restoration, in the same environment, after
target={'model': abstract_state}  # a dummy training state, used only as the restore target
chpt_restored = checkpoint_manager.restore(checkpoint_manager.latest_step(), items=target)
one restores 8 versions using
new_model_state = flax.jax_utils.replicate(chpt_restored['model'])
but these are 8 replicated copies of the same model instance.
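As far as I understand pmap-style data parallelism, getting 8 identical copies back is expected: each core holds its own copy of the state, and as long as the copies are kept in sync they are all the same values, so saving just one of them loses nothing. A minimal sketch of that save/restore round trip in plain JAX, under these assumptions: replicate and unreplicate are hypothetical helpers that mimic flax.jax_utils.replicate/unreplicate (Flax's replicate places the copies on the devices directly, while this one only broadcasts a leading device axis that pmap would shard when called), and jax.device_get stands in for the real Orbax save/restore calls.

```python
import jax
import jax.numpy as jnp
import numpy as np

N_DEV = jax.local_device_count()  # e.g. 8 on the Colab TPU runtime, 1 on CPU


def replicate(tree):
    # Hypothetical stand-in for flax.jax_utils.replicate: add a leading
    # device axis holding one identical copy of every leaf.
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (N_DEV,) + jnp.shape(x)), tree)


def unreplicate(tree):
    # Same idea as flax.jax_utils.unreplicate: keep only the first copy.
    return jax.tree_util.tree_map(lambda x: x[0], tree)


# A single model instance, as it exists before training starts.
state = {"w": jnp.arange(3.0), "step": jnp.array(0)}

replicated = replicate(state)

# Every replica holds the same values, so keeping one of them loses nothing.
all_same = all(
    bool(jnp.all(leaf == leaf[0]))
    for leaf in jax.tree_util.tree_leaves(replicated))

# "Save": bring ONE copy to host memory (stand-in for checkpoint_manager.save).
saved = jax.device_get(unreplicate(replicated))

# "Restore": the checkpoint holds one instance; replicate it again for pmap.
restored = replicate(saved)
```
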
It may be intended to behave like that, but I am wondering how one can resume a first training session and continue the training, given that one only has a single instance at the start of the second training session. I hope I have been clear. Any comment on the Colab snippet is welcome.
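Resuming would then follow the same pattern: restore the single instance, replicate it, and keep calling the pmapped train step, so the step counter and parameters continue from where the first session stopped. A minimal sketch under stated assumptions: plain JAX only, a toy linear-regression "model" stored as a dict, hypothetical replicate/unreplicate helpers mimicking the Flax ones, and an in-memory checkpoint standing in for checkpoint_manager.save/restore. The jax.lax.pmean over the gradients is what keeps the replicas identical, which is why saving only one of them is safe.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

N_DEV = jax.local_device_count()


def replicate(tree):
    # Hypothetical stand-in for flax.jax_utils.replicate.
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (N_DEV,) + jnp.shape(x)), tree)


def unreplicate(tree):
    # Same idea as flax.jax_utils.unreplicate: keep the first copy.
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(state, x, y):
    grads = jax.grad(loss_fn)(state["w"], x, y)
    # Average gradients over devices so every replica applies the same
    # update and the copies never drift apart.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return {"w": state["w"] - 0.1 * grads, "step": state["step"] + 1}


# Toy data sharded along a leading device axis: (N_DEV, per_device_batch, 3).
x = jax.random.normal(jax.random.PRNGKey(0), (N_DEV, 8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])


def full_loss(single_state):
    # Loss of one (unreplicated) state over all shards.
    return float(loss_fn(single_state["w"], x.reshape(-1, 3), y.reshape(-1)))


# --- Session 1: train, then save ONE instance. ---
state = replicate({"w": jnp.zeros(3), "step": jnp.array(0)})
for _ in range(5):
    state = train_step(state, x, y)
checkpoint = jax.device_get(unreplicate(state))  # stand-in for the Orbax save

# --- Session 2: restore the single instance, replicate, keep training. ---
state = replicate(checkpoint)  # stand-in for restore + replicate
for _ in range(5):
    state = train_step(state, x, y)
final = unreplicate(state)
```

On a single-device CPU runtime N_DEV is 1 and the same code runs unchanged; on an 8-core TPU runtime the leading axis of x and of every state leaf has size 8, and the restored single instance is simply broadcast to all 8 cores before training resumes.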