I have a list of objects, each of which has a function to be applied to a slice of a jax.numpy.array. There are n objects and n corresponding slices. How can I vectorise this using vmap?
For example, for the following code snippet:
import jax
import jax.numpy as jnp

class Obj:
    def __init__(self, i):
        self.i = i

    def f1(self, x):
        return x - self.i

x = jnp.arange(9).reshape(3, 3).astype(jnp.float32)
functions_obj = [Obj(1).f1, Obj(2).f1, Obj(3).f1]
how would I apply the functions in functions_obj to slices of x?
More details, probably not relevant:
My specific use case is running member functions of many Reinforcement Learning Gym environment objects on slices of an actions array, but I believe the problem is more general, so I formulated it as above. (P.S.: I know about AsyncVectorEnv, but it does not solve my problem because I am not trying to run the step function.)
1 Answer
Use jax.lax.switch to select between the functions in the list and map over the desired axis of x at the same time:
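def apply_func_obj(i, x_slice):
    # Call the i-th function in functions_obj on one slice of x
    return jax.lax.switch(i, functions_obj, x_slice)

indices = jnp.arange(len(functions_obj))
# vmap over the function index and the rows of x together,
# so that row k of x is passed to functions_obj[k]
results = jax.vmap(apply_func_obj, in_axes=(0, 0))(indices, x)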
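
For reference, here is a self-contained sketch that puts the question's setup and the switch-based answer together and checks the result against a plain Python loop. The names `expected`, `offsets`, and `alt` are not from the original post and are only there for illustration. The last part shows an alternative that is only an assumption about your use case: if the objects differ solely in a numeric parameter, you can vmap over that parameter directly and skip `switch` altogether.

```python
import jax
import jax.numpy as jnp


class Obj:
    def __init__(self, i):
        self.i = i

    def f1(self, x):
        return x - self.i


x = jnp.arange(9).reshape(3, 3).astype(jnp.float32)
functions_obj = [Obj(1).f1, Obj(2).f1, Obj(3).f1]


def apply_func_obj(i, x_slice):
    # Call the i-th function in functions_obj on one slice of x
    return jax.lax.switch(i, functions_obj, x_slice)


indices = jnp.arange(len(functions_obj))
# Row k of x is passed to functions_obj[k]
results = jax.vmap(apply_func_obj, in_axes=(0, 0))(indices, x)

# Reference result from an ordinary Python loop over (function, row) pairs
expected = jnp.stack([f(row) for f, row in zip(functions_obj, x)])
print(jnp.allclose(results, expected))

# Alternative (assumption: the objects differ only in the value of `i`):
# stack the parameters into an array and vmap over them directly.
offsets = jnp.array([1.0, 2.0, 3.0])  # hypothetical stacked parameters
alt = jax.vmap(lambda i, row: row - i)(offsets, x)
print(jnp.allclose(alt, expected))
```

Two caveats, as far as I understand the behaviour: every function passed to `jax.lax.switch` has to be traceable by JAX and return outputs of the same shape and dtype, and when the index is batched by `vmap` the switch is typically lowered to a select, so all branches may be evaluated for every row. That is fine for cheap functions like the ones here, but it may not help with methods that run arbitrary Python code.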