python 3.x - Using vmap to parallelize multiple models input in JAX/Flax.nnx - Stack Overflow

I have a function that is roughly as follows

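from flax import nnx
from jax import Array
from typing import List

def predict(models: List[nnx.Module], imgs: Array):
    predictions = []
    # Apply the i-th model to the i-th image, one model at a time
    for i, model in enumerate(models):
        predictions.append(model(imgs[i, ...]))
    return predictions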

I want to avoid looping over these models to get the predictions. Ideally I would use jax.vmap or nnx.vmap to create a new function predict that runs all the predictions in parallel on the GPU. However, I can't simply pass a Python list of models as the mapped argument.

My desired solution:

from jax import vmap

def predict_single(model, img):
    return model(img)

predict = vmap(predict_single)

Error message:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
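The message itself is generic: jax.vmap raises it whenever a leaf of an argument that is mapped along axis 0 (the default in_axes) is a scalar, so there is nothing to split. The question doesn't show which leaf of the models argument triggered it. The minimal pure-JAX reproduction below uses a hypothetical scalar argument `scale` in place of the model, and shows that marking such an argument as unmapped (in_axes=None) broadcasts it instead.

```python
import jax
import jax.numpy as jnp

def predict_single(scale, img):
    # Hypothetical stand-in for model(img): scale the image by a scalar.
    return scale * img

imgs = jnp.ones((3, 4))   # batch of 3 images
scale = jnp.float32(2.0)  # rank-0 leaf: there is no axis 0 to map over

msg = ""
try:
    # Default in_axes=0 asks vmap to map *both* arguments along axis 0.
    jax.vmap(predict_single)(scale, imgs)
except ValueError as err:
    msg = str(err)
print(msg)

# in_axes=(None, 0): broadcast scale, map only over the image batch.
out = jax.vmap(predict_single, in_axes=(None, 0))(scale, imgs)
print(out.shape)  # (3, 4)
```

Broadcasting with None only helps when every image should see the same argument. Mapping different models over different images needs either a switch on a model index or weights stacked along a leading axis.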

How should I solve this?

Asked Mar 17 at 16:24 by elderly

1 Answer

You can do this with jax.lax.switch. Assuming imgs is batched along its leading axis and imgs[i] should be passed to models[i], it would look like this:

import jax
import jax.numpy as jnp
func = jax.vmap(lambda i, img: jax.lax.switch(i, models, img))  # i picks the model for each image
result = func(jnp.arange(len(imgs)), imgs)
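A self-contained sketch of the switch approach follows. The models here are hypothetical stand-ins: plain functions built by a hypothetical make_model helper, each applying its own (4, 2) weight matrix, rather than nnx.Module instances. The sketch compares the vmap/switch result with the question's loop.

```python
import jax
import jax.numpy as jnp

def make_model(w):
    # Hypothetical "model": a linear map with its own weights.
    # Every model returns the same shape/dtype, as lax.switch requires.
    def model(img):
        return img @ w
    return model

W = jnp.arange(24.0).reshape(3, 4, 2)           # one (4, 2) weight matrix per model
models = [make_model(W[k]) for k in range(3)]
imgs = jnp.ones((3, 4))                         # imgs[i] goes to models[i]

# vmap over (model index, image) pairs; lax.switch dispatches on the index.
func = jax.vmap(lambda i, img: jax.lax.switch(i, models, img))
result = func(jnp.arange(len(imgs)), imgs)

# Reference: the plain Python loop from the question.
expected = jnp.stack([m(img) for m, img in zip(models, imgs)])
print(result.shape)  # (3, 2)
```

Because the index i is batched, vmap lowers lax.switch to a select. Every model is evaluated on every image and the matching outputs are then picked. This removes the Python loop, but it does not reduce the amount of computation.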

Note that this requires every model to return an array with the same structure (that is, the same shape and dtype). If your models don't meet this requirement, there is no way to execute them within vmap, and your best bet is a normal Python for loop.
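Suppose all the models share one architecture and differ only in their weights. Another option, swapped in here as an alternative and not taken from the answer above, is to stack the weights along a new leading axis. A single apply function is then vmapped over that axis together with imgs, so each model runs only on its own image. The sketch below uses a hypothetical one-hidden-layer MLP written as a pure JAX function of a parameter dict (init_params and apply are illustrative names) so that it doesn't depend on a particular Flax version. In Flax NNX the analogous pattern is to create and call the models through nnx.vmap, which is not shown here.

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim=4, hidden=8, out_dim=2):
    # Hypothetical shared architecture: parameters are a dict (a pytree) of arrays.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, out_dim)),
        "b2": jnp.zeros(out_dim),
    }

def apply(params, img):
    # Forward pass of ONE model on ONE image.
    h = jax.nn.relu(img @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

num_models = 3
keys = jax.random.split(jax.random.PRNGKey(0), num_models)
params_list = [init_params(k) for k in keys]    # the "list of models"

# Stack leaf by leaf: every leaf gains a leading axis of size num_models.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *params_list)

# Map over the model axis of the params and the batch axis of imgs together.
predict = jax.vmap(apply, in_axes=(0, 0))

imgs = jnp.ones((num_models, 4))
result = predict(stacked, imgs)
print(result.shape)  # (3, 2)
```

Unlike the switch version, the stacked version does one model's worth of work per image. It only applies when every model has identically shaped parameters.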
