I'm very new to JAX, so there might be some concepts I don't fully understand yet. I'm trying to adapt my problem to JAX and use MCTX.
For my particular problem, the input is a string (the state, in RL terminology). I'm trying to launch multiple MCTS runs simultaneously using vmap. Since my input is a string, I vectorized it by mapping each character to an integer index. Up to this point everything works fine: my inputs are all the same length, and so on.
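A minimal sketch of that kind of encoding, assuming a fixed lowercase alphabet and equal-length strings (the alphabet, `encode`, `decode`, and `count_a` are hypothetical names for illustration). Equal-length encoded strings stack into one integer array, and any function written with jax.numpy can then be vmapped over it:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical alphabet; the real one depends on the problem.
ALPHABET = "abcdefghijklmnopqrstuvwxyz"
CHAR_TO_ID = {c: i for i, c in enumerate(ALPHABET)}

def encode(s: str) -> np.ndarray:
    """Map each character to its index in ALPHABET (plain Python, host side)."""
    return np.array([CHAR_TO_ID[c] for c in s], dtype=np.int32)

def decode(ids) -> str:
    """Inverse of encode; needs concrete values, so it fails on traced arrays inside jit/vmap."""
    return "".join(ALPHABET[int(i)] for i in ids)

# Equal-length strings stack into a single (batch, length) int32 array.
batch = jnp.asarray(np.stack([encode(s) for s in ["abcd", "dcba", "aaaa"]]))

# A pure jax.numpy function of one encoded string vectorizes over the batch.
def count_a(ids):
    return jnp.sum(ids == CHAR_TO_ID["a"])

counts = jax.vmap(count_a)(batch)
```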
The problem arises when transitioning from one state to another within MCTS. To move to the next state, I need to call an external function that transforms the string. Ideally, I would convert the vector back into a string, pass it to this external function, receive the transformed string for the next state, and then convert that back into a vector.
I'm unsure whether this is feasible in JAX, or whether JAX is well suited to my problem at all. Is there a workaround? I'm also concerned that any workaround might negate the performance benefits of using JAX in the first place.
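If the external transformation can be rewritten with jax.numpy operations on the encoded array, it stays traceable, compiles under jit, and vectorizes under vmap with no host round trip. A minimal sketch, assuming a hypothetical transformation that replaces each character with the next letter of the alphabet (wrapping "z" to "a"), encoded as a lookup table:

```python
import jax
import jax.numpy as jnp

ALPHABET = "abcdefghijklmnopqrstuvwxyz"  # hypothetical alphabet

# Hypothetical rule as a lookup table: SUBSTITUTION[i] is the id that character id i becomes.
SUBSTITUTION = (jnp.arange(len(ALPHABET), dtype=jnp.int32) + 1) % len(ALPHABET)

def step_pure(ids: jax.Array) -> jax.Array:
    # A gather from the table: pure array code, so it runs entirely on the device.
    return SUBSTITUTION[ids]

batch = jnp.array([[0, 1, 25], [2, 2, 2]], dtype=jnp.int32)
next_states = jax.jit(jax.vmap(step_pure))(batch)
```

Whether this is possible depends entirely on what the external function does; a transformation that changes the string length would also need a fixed maximum length plus padding.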
Any help would be greatly appreciated!