The method `keras.Model.set_weights` seems to only take trainable weights. Non-trainable weights, such as the moving statistics of normalization layers, cannot be set this way. This is a problem because in Keras with the JAX backend we can only perform stateless operations; that is, the weights are updated separately from the model. Therefore, before saving a model with `keras.Model.save`, the updated weights have to be loaded back into it. However, since the non-trainable weights cannot be loaded (and are therefore not saved in the `.keras` format), the saved model underperforms.
Is it possible to load/set the non-trainable weights in a Keras model? More generally, is there any way to save the complete model, including its non-trainable weights, when using JAX as the backend?