I have a function that is roughly as follows:
from flax import nnx
from jax import Array
from typing import List

def predict(models: List[nnx.Module], imgs: Array):
    predictions = []
    for i, model in enumerate(models):
        # Apply the i-th model to the i-th image
        predictions.append(model(imgs[i, ...]))
    return predictions
I want to avoid looping over these models to get the predictions. Ideally, I would use jax.vmap or nnx.vmap to create a new predict function that runs everything in parallel on the GPU. However, I (obviously) can't pass a list as the argument.
My attempted solution:
from jax import vmap

def predict_single(model, img):
    return model(img)

# Intended usage: predict(models, imgs) should apply models[i] to imgs[i]
predict = vmap(predict_single)
Error message:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
How should I solve this?
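For context, jax.vmap treats each argument as a pytree and maps axis 0 of every leaf, so any rank-0 (scalar) leaf in a mapped argument raises exactly this ValueError. The sketch below reproduces it with a hypothetical parameter dict standing in for several models' state (that the asker's modules contain such a scalar leaf is an assumption, since the full traceback isn't shown), then marks the scalar leaf as unmapped with in_axes=None.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the state of 3 models: "w" has a leading model
# axis, while "offset" is a rank-0 scalar shared by all models (no such axis).
params = {"w": jnp.ones((3, 4, 2)), "offset": jnp.array(0.0)}
imgs = jnp.ones((3, 4))

def apply(p, img):
    return img @ p["w"] + p["offset"]

# Default in_axes=0 tries to map axis 0 of every leaf and fails on "offset".
try:
    jax.vmap(apply)(params, imgs)
    error = None
except ValueError as e:
    error = str(e)

# Marking the scalar leaf as unmapped (None) broadcasts it instead.
fixed = jax.vmap(apply, in_axes=({"w": 0, "offset": None}, 0))(params, imgs)
print(error)
print(fixed.shape)  # (3, 2)
```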
Asked Mar 17 at 16:24 by elderly.

Answer:
You can do something like this using jax.lax.switch. Assuming you have batched imgs, where imgs[i] should be passed to models[i], it would look something like this:
import jax
import jax.numpy as jnp
func = jax.vmap(lambda i, img: jax.lax.switch(i, models, img))  # models[i] applied to imgs[i]
result = func(jnp.arange(len(imgs)), imgs)
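A self-contained version of the snippet above that runs without Flax: the three "models" here are hypothetical plain-JAX closures standing in for the nnx modules (nnx modules are also callables, which is all lax.switch needs from a branch). The sketch also checks the vmapped result against the plain Python loop from the question.

```python
import jax
import jax.numpy as jnp

def make_model(scale):
    """Hypothetical stand-in for an nnx.Module: a fixed linear map (4,) -> (3,)."""
    w = jnp.full((4, 3), scale, dtype=jnp.float32)
    return lambda img: img @ w

# Every branch must return the same shape/dtype for lax.switch.
models = [make_model(s) for s in (1.0, 2.0, 3.0)]
imgs = jnp.ones((3, 4), dtype=jnp.float32)  # one image per model

# Inside vmap, i is a per-example index selecting which model to apply.
func = jax.vmap(lambda i, img: jax.lax.switch(i, models, img))
result = func(jnp.arange(len(imgs)), imgs)

# Reference result from the plain Python loop.
expected = jnp.stack([m(imgs[i]) for i, m in enumerate(models)])
print(result)
```

Row i of result equals models[i](imgs[i]). One thing worth knowing: because the index is batched inside vmap, JAX cannot branch per example, so it evaluates every branch and then selects the matching outputs. With N models and N images that is roughly N times the model evaluations of the loop, traded for a single vectorized computation.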
Note that this requires every model to return an output with the same structure (i.e. the same shape and dtype). If your models don't meet this requirement, then there is no way to execute them within vmap, and your best bet would be to use a normal Python for loop.
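When the models additionally share one architecture and differ only in their weights (an assumption; the question doesn't say so), a common alternative to lax.switch is to stack the weights along a new leading axis and vmap a single apply function over weights and images together, so each image passes only through its own model. The sketch uses a hypothetical plain-JAX linear model; with Flax NNX, the per-model weights could be extracted with nnx.split and reassembled with nnx.merge, or the ensemble built directly with nnx.vmap, but that wiring is not shown here.

```python
import jax
import jax.numpy as jnp

def apply(params, img):
    """Hypothetical model: one linear layer with explicitly passed parameters."""
    return img @ params["w"] + params["b"]

# Per-model weights with identical shapes (required for stacking).
param_list = [
    {"w": jnp.full((4, 3), s, dtype=jnp.float32), "b": jnp.zeros(3, dtype=jnp.float32)}
    for s in (1.0, 2.0, 3.0)
]
# Stack each leaf along a new leading "model" axis: w -> (3, 4, 3), b -> (3, 3).
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *param_list)

imgs = jnp.ones((3, 4), dtype=jnp.float32)  # image i goes to model i
# vmap maps axis 0 of every parameter leaf and of imgs in lockstep.
result = jax.vmap(apply)(stacked, imgs)

# Reference result from a plain Python loop over the unstacked weights.
expected = jnp.stack([apply(p, imgs[i]) for i, p in enumerate(param_list)])
```

Unlike the lax.switch version, this evaluates each model only on its own image, but it cannot express models whose architectures differ; for those, the Python loop remains the fallback.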