
python - Flax nnx/jax: tree.map for layers of incongruent size - Stack Overflow


I am trying to figure out how to use nnx.split_rngs. Can somebody give a version of the code below that uses nnx.split_rngs with jax.tree.map to produce an arbitrary number of Linear layers with different out_features?

import jax
from flax import nnx
from functools import partial

if __name__ == '__main__':

    session_sizes = {
        'a':2,
        'b':3,
        'c':4,
        'd':5,
        'e':6,
    }
    dz = 2

    rngs = nnx.Rngs(0)
    
    my_linear = partial(
        nnx.Linear,
        use_bias=False,
        in_features=dz,
        rngs=rngs,
    )

    def my_linear_wrapper(a):
        return my_linear(out_features=a)

    q_s = jax.tree.map(my_linear_wrapper, session_sizes)

    for k in session_sizes.keys():
        print(q_s[k].kernel)

So in this case, we would need a tree of layers that will take our 2 in_features into spaces of 2, ..., 6 out_features.

The function my_linear_wrapper is a workaround for the solution we originally had in mind, which was to map in much the same fashion as we're doing here, but using (something like) the @nnx.split_rngs function decorator instead.

Is there a way to use nnx.split_rngs on my_linear in order to map over the rngs argument to nnx.Linear?



1 Answer

split_rngs is mostly useful when you are going to pass the Rngs through a transform like vmap. Here you want to produce variable-sized Modules, so the current solution is the way to go. Because partial also accepts positional arguments, you can simplify it to:

din = 2
rngs = nnx.Rngs(0)

my_linear = functools.partial(
  nnx.Linear, din, use_bias=False, rngs=rngs
)

q_s = jax.tree.map(my_linear, session_sizes)

for k in session_sizes.keys():
  print(q_s[k].kernel)