There is a function in my codebase that is "already vmapped", i.e. when fed an array of shape (M, N) it outputs another array of shape (M, N). I would like to take the "row-wise Jacobian" of this function, i.e. a function that returns an array of shape (M, N, N). So far I've achieved this in a hacky way by adding a dummy extra dimension and vmapping, as illustrated in the example below, but it really feels like there should be a better way to do this. Does anyone have any ideas?
Example of what I want:
```python
import jax
from jax import numpy as jnp

rng = jax.random.PRNGKey(42)
A = jax.random.normal(rng, shape=(128, 16, 16))

# This is the function I would like to take the row-wise Jacobian of
def already_vmapped(x, A):
    vmap_mul = jax.vmap(jnp.matmul)
    return vmap_mul(A, x)

# This is how I am doing it now
function = lambda x, A: jax.vmap(jax.jacobian(lambda x, A: already_vmapped(x[None,], A[None,]),
                                              argnums=0))(x, A).squeeze()
x = jnp.ones((128, 16))
print(jnp.allclose(function(x, A), A))  # True

# A slightly cleaner way that is too memory intensive
function_mem = lambda x, A: jnp.diagonal(jax.jacobian(already_vmapped, argnums=0)(x, A),
                                         offset=0, axis1=0, axis2=2).transpose(2, 0, 1)
print(jnp.allclose(function_mem(x, A), A))  # True
```
I understand that the absolute cleanest way would be to not vmap the original function in the first place, but for various reasons that is not easy to undo in my codebase right now. Any other suggestions are welcome!
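To make the memory cost of the diagonal-based version concrete: `jax.jacobian(already_vmapped)` differentiates every output row with respect to every input row, so it builds an (M, N, M, N) array in which only the M diagonal (N, N) blocks are nonzero. The sketch below uses `jax.eval_shape`, which traces the computation abstractly without allocating the result, and assumes the example's sizes and JAX's default float32 dtype (x64 disabled).

```python
import math

import jax
from jax import numpy as jnp

M, N = 128, 16
x = jnp.ones((M, N))
A = jnp.ones((M, N, N))  # values don't matter for shape inspection

def already_vmapped(x, A):
    return jax.vmap(jnp.matmul)(A, x)

# eval_shape only computes the output shape/dtype; nothing is materialized.
full = jax.eval_shape(jax.jacobian(already_vmapped, argnums=0), x, A)
print(full.shape)  # (128, 16, 128, 16)

itemsize = full.dtype.itemsize  # 4 bytes for float32
full_bytes = math.prod(full.shape) * itemsize
rowwise_bytes = M * N * N * itemsize
print(full_bytes, rowwise_bytes)  # 16777216 131072
print(full_bytes // rowwise_bytes)  # 128, i.e. a factor of M
```

The full Jacobian is M times larger than the row-wise one, so the overhead grows linearly with the batch size.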
Asked Mar 31 at 11:26 by Kimon P.

1 Answer
I think what you're already doing is more or less the best approach. You want to vmap the `jax.jacobian` over the rows, and within each row you want to compute a size-1 batch of your original "already vmapped" function.
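To see what the size-1 batch trick does to the shapes, here is a small sketch that reuses the question's `already_vmapped` and example sizes on a single row. The Jacobian of a batch of one row carries two singleton batch axes (one from the output, one from the input), which is why the result has to be squeezed.

```python
import jax
from jax import numpy as jnp

rng = jax.random.PRNGKey(42)
A = jax.random.normal(rng, shape=(128, 16, 16))
x = jnp.ones((128, 16))

def already_vmapped(x, A):
    return jax.vmap(jnp.matmul)(A, x)

# One row, kept as a batch of size 1: shapes (1, 16) and (1, 16, 16).
x_row, A_row = x[0][None], A[0][None]
J_row = jax.jacobian(already_vmapped, argnums=0)(x_row, A_row)
print(J_row.shape)  # (1, 16, 1, 16): (out batch, out dim, in batch, in dim)

# Dropping the two singleton batch axes leaves this row's (N, N) Jacobian.
J_row = J_row.squeeze((0, 2))
print(J_row.shape)  # (16, 16)
print(jnp.allclose(J_row, A[0]))  # True, since the function is linear in x
```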
For clarity, I'd probably re-express your initial approach this way:
```python
def f_single_batch(x, A):
    return already_vmapped(x[None], A[None]).squeeze(0)

result = jax.vmap(jax.jacobian(f_single_batch, 0))(x, A)
```
A slightly more direct version might look like this:
```python
result = jax.vmap(jax.jacobian(already_vmapped, 0))(x[:, None], A[:, None]).squeeze((1, 3))
```
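As a quick sanity check (a sketch reusing the question's example setup), both formulations should produce an (M, N, N) array, and because this particular function is linear in `x`, each row's Jacobian is just `A[i]`. Before the `.squeeze((1, 3))`, the direct version has shape (128, 1, 16, 1, 16); axes 1 and 3 are the two singleton batch axes.

```python
import jax
from jax import numpy as jnp

rng = jax.random.PRNGKey(42)
A = jax.random.normal(rng, shape=(128, 16, 16))
x = jnp.ones((128, 16))

def already_vmapped(x, A):
    return jax.vmap(jnp.matmul)(A, x)

def f_single_batch(x, A):
    return already_vmapped(x[None], A[None]).squeeze(0)

first = jax.vmap(jax.jacobian(f_single_batch, 0))(x, A)

# Direct version, before squeezing: (M, 1, N, 1, N).
unsqueezed = jax.vmap(jax.jacobian(already_vmapped, 0))(x[:, None], A[:, None])
print(unsqueezed.shape)  # (128, 1, 16, 1, 16)
direct = unsqueezed.squeeze((1, 3))

print(first.shape, direct.shape)  # (128, 16, 16) (128, 16, 16)
print(jnp.allclose(first, direct), jnp.allclose(first, A))  # True True
```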
But I would lean toward the first version because it's easier to understand, and anybody reading the code (including your future self) will thank you.
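If this pattern shows up in several places, one option is to wrap the first version in a small helper. `rowwise_jacobian` below is a hypothetical name, not a JAX API; it is a sketch that assumes the wrapped function treats each row of its first argument independently and that every argument is batched along axis 0.

```python
import jax
from jax import numpy as jnp

def rowwise_jacobian(batched_fn):
    """Hypothetical helper: given batched_fn(x, *args) mapping (M, N) -> (M, P)
    row by row, return a function computing the (M, P, N) row-wise Jacobian
    with respect to x. Assumes all arguments are batched along axis 0."""
    def single_row(x_row, *arg_rows):
        # Re-add a batch axis of size 1 so batched_fn sees the shapes it expects.
        batched_args = [a[None] for a in arg_rows]
        return batched_fn(x_row[None], *batched_args).squeeze(0)
    # Jacobian of one row, vmapped over all M rows (and all arguments).
    return jax.vmap(jax.jacobian(single_row, argnums=0))

def already_vmapped(x, A):
    return jax.vmap(jnp.matmul)(A, x)

rng = jax.random.PRNGKey(42)
A = jax.random.normal(rng, shape=(128, 16, 16))
x = jnp.ones((128, 16))

J = rowwise_jacobian(already_vmapped)(x, A)
print(J.shape)  # (128, 16, 16)
print(jnp.allclose(J, A))  # True
```

The helper only hides the `[None]`/`.squeeze(0)` bookkeeping; the computation is the same vmap-of-jacobian as above.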