I am trying to figure out how to use nnx.split_rngs. Can somebody give a version of the code below that uses nnx.split_rngs with jax.tree.map to produce an arbitrary number of Linear layers with different out_features?
import jax
from flax import nnx
from functools import partial

if __name__ == '__main__':
    session_sizes = {
        'a': 2,
        'b': 3,
        'c': 4,
        'd': 5,
        'e': 6,
    }
    dz = 2
    rngs = nnx.Rngs(0)
    my_linear = partial(
        nnx.Linear,
        use_bias=False,
        in_features=dz,
        rngs=rngs,
    )

    def my_linear_wrapper(a):
        return my_linear(out_features=a)

    q_s = jax.tree.map(my_linear_wrapper, session_sizes)
    for k in session_sizes.keys():
        print(q_s[k].kernel)
So in this case, we would need a tree of layers that will take our 2 in_features into spaces of 2, ..., 6 out_features.
The function my_linear_wrapper is a workaround for the solution we originally had in mind: map in much the same fashion as above, but use (something like) the @nnx.split_rngs function decorator instead.
Is there a way to use nnx.split_rngs on my_linear in order to map over the rng argument to nnx.Linear?
1 Answer
split_rngs is mostly useful when you are going to pass the Rngs through a transform like vmap. Here you want to produce variable-sized Modules, so your current solution is the way to go. Because of how partial works, you can simplify it to:
import functools

import jax
from flax import nnx

# session_sizes as defined in the question
din = 2
rngs = nnx.Rngs(0)
# in_features is bound positionally, so each tree leaf supplies out_features directly
my_linear = functools.partial(nnx.Linear, din, use_bias=False, rngs=rngs)

q_s = jax.tree.map(my_linear, session_sizes)
for k in session_sizes.keys():
    print(q_s[k].kernel)
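For completeness, here is the kind of situation where nnx.split_rngs does earn its keep: splitting an Rngs stream so that an nnx.vmap-transformed constructor builds a stack of layers with distinct parameters. The sketch below follows the decorator pattern from the Flax NNX transforms documentation (exact signatures may vary across Flax versions, so treat it as a sketch rather than a drop-in recipe). Note that vmap forces every layer in the stack to share the same out_features, which is exactly why it does not fit the variable-sized case above.

# Minimal sketch: nnx.split_rngs + nnx.vmap to create 5 Linear layers with
# identical shapes but independent parameters, stacked along a leading axis.
import jax
from flax import nnx

@nnx.split_rngs(splits=5)             # split the rngs stream into 5 independent keys
@nnx.vmap(in_axes=(0,), out_axes=0)   # map the constructor over the split keys
def create_linears(rngs: nnx.Rngs):
    return nnx.Linear(2, 3, use_bias=False, rngs=rngs)

stacked = create_linears(nnx.Rngs(0))
print(stacked.kernel.value.shape)     # expected (5, 2, 3): one kernel slice per layer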