I am running some physics simulations using JAX, and they involve a function that constructs the Hamiltonian, defined as follows:
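```python
from functools import partial

from jax import jit

# create(n) and annhilate(n) are helper functions defined elsewhere.

# Constructing the Hamiltonian
@partial(jit, static_argnames=['n', 'omega'])
def hamiltonian(n: int, omega: float):
    """Construct the Hamiltonian for the system."""
    H = omega * create(n) @ annhilate(n)
    return H
```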
and then a bigger function, `solve_diff(n, omega, kappa, alpha0)`, defined as follows:
```python
@partial(jit, static_argnames=['n', 'omega'])
def solve_diff(n, omega, kappa, alpha0):
    # Some functionality that uses kappa and alpha0
    H = hamiltonian(n, omega)
    # returns an expectation value
```
When I try to compute the gradient of this function using `jax.grad`:
```python
from jax import grad

n = 16
omega = 1.0
kappa = 0.1
alpha0 = 1.0

# Compute gradients with respect to omega, kappa, and alpha0
grad_population = grad(solve_diff, argnums=(1, 2, 3))
grads = grad_population(n, omega, kappa, alpha0)
print(f"Gradient w.r.t. omega: {grads[0]}")
print(f"Gradient w.r.t. kappa: {grads[1]}")
print(f"Gradient w.r.t. alpha0: {grads[2]}")
```
it raises the following error:

```
ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.ad.JVPTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<JVPTrace> with
  primal = 1.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding(). The error was:
TypeError: unhashable type: 'JVPTracer'
```
However, running `solve_diff(16,1.0,0.1,1.0)` on its own works as expected.
If I remove `omega` from the list of static arguments of both the `hamiltonian` function and `solve_diff`, the gradient is computed as expected.
This is confusing me, because I no longer know what qualifies as a static or dynamic variable. By the definition that static variables do not change between function calls, both `n` and `omega` are constants and indeed should not change between function calls.
1 Answer
The fundamental issue is that you cannot differentiate with respect to a static variable; if you try to do so, you will get the error you observed.
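A minimal sketch of the fix, under stated assumptions: mark only `n` as static, since it fixes array shapes (`jnp.arange(n)` needs a concrete integer), and let `omega` be traced. The real `create`, `annhilate`, and `solve_diff` body were not shown in the question, so the versions below are hypothetical stand-ins (truncated bosonic ladder operators and a toy quadratic form) used only to make the example runnable.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, jit


# Hypothetical stand-ins for the question's helpers: truncated bosonic
# ladder operators in an n-dimensional number basis.
def annhilate(n: int):
    # a|k> = sqrt(k)|k-1>: sqrt(1..n-1) on the first superdiagonal
    return jnp.diag(jnp.sqrt(jnp.arange(1, n, dtype=jnp.float32)), k=1)


def create(n: int):
    # For this real matrix, the creation operator is the transpose of a
    return annhilate(n).T


# Only n is static: it determines array shapes, so it must be a Python int.
# omega is traced, so jax.grad is free to differentiate with respect to it.
@partial(jit, static_argnames=["n"])
def hamiltonian(n: int, omega: float):
    """Construct H = omega * a^dagger a."""
    return omega * create(n) @ annhilate(n)


@partial(jit, static_argnames=["n"])
def solve_diff(n, omega, kappa, alpha0):
    H = hamiltonian(n, omega)
    # Hypothetical toy state with amplitudes alpha0 * exp(-kappa * k)
    psi = alpha0 * jnp.exp(-kappa * jnp.arange(n, dtype=jnp.float32))
    # Scalar <psi|H|psi>, standing in for the real expectation value
    return psi @ H @ psi


n, omega, kappa, alpha0 = 16, 1.0, 0.1, 1.0
grads = grad(solve_diff, argnums=(1, 2, 3))(n, omega, kappa, alpha0)
print(f"Gradient w.r.t. omega: {grads[0]}")
print(f"Gradient w.r.t. kappa: {grads[1]}")
print(f"Gradient w.r.t. alpha0: {grads[2]}")
```

Conversely, keeping `omega` in `static_argnames` is fine as long as you never differentiate with respect to it, for example with `grad(solve_diff, argnums=(2, 3))`; each new value of `omega` would then trigger a recompilation.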
> This is confusing me, because I no longer know what qualifies as a static or dynamic variable. By the definition that static variables do not change between function calls
In JAX, the term "static" does not have to do with whether a variable changes between function calls. Rather, a static variable is one that does not participate in tracing, which is the mechanism used to compute transformations like `vmap`, `grad`, `jit`, etc. When you differentiate with respect to a variable, it is no longer static, because it participates in the autodiff transformation: `grad` replaces `omega` with a tracer (the `JVPTracer` in your traceback), and when `jit` then tries to hash that tracer as a static argument, it fails with the error you saw.
For a discussion of transformations, tracing, and related concepts, I'd start with JAX Key Concepts: transformations.