
python - Why is JAX treating floating-point values as tracers rather than concretizing them when nesting jitted functions?


I am doing some physics simulations using JAX, and they involve a function that constructs the Hamiltonian, defined as follows:

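```python
from functools import partial
from jax import jit

# Constructing the Hamiltonian (create and annhilate are defined elsewhere, not shown)
@partial(jit, static_argnames=['n', 'omega'])
def hamiltonian(n: int, omega: float):
    """Construct the Hamiltonian for the system."""
    H = omega * create(n) @ annhilate(n)
    return H
```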
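
The question does not show `create` and `annhilate`. As a sketch, assuming they build the usual truncated bosonic ladder operators on an `n`-level space (the helper bodies below are hypothetical stand-ins, not the asker's code), `hamiltonian` runs directly with both `n` and `omega` static:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit


# Hypothetical helpers: truncated ladder operators on an n-level Fock space.
# The question does not show its own `create`/`annhilate`; these are stand-ins.
def annhilate(n: int):
    """Annihilation operator a, with a|k> = sqrt(k)|k-1>, truncated to n levels."""
    return jnp.diag(jnp.sqrt(jnp.arange(1, n, dtype=jnp.float32)), k=1)


def create(n: int):
    """Creation operator a^dagger: the transpose of the (real) annihilation matrix."""
    return annhilate(n).T


@partial(jit, static_argnames=['n', 'omega'])
def hamiltonian(n: int, omega: float):
    """H = omega * a^dagger a, i.e. the number operator scaled by omega."""
    return omega * create(n) @ annhilate(n)


H = hamiltonian(4, 2.0)
print(jnp.diag(H))  # approximately omega * [0, 1, 2, 3]
```

Because `omega` is a plain Python float in this call, it is hashable and `jit` accepts it as a static argument; the problem only appears once `grad` is involved.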

I then have a larger function, `solve_diff(n, omega, kappa, alpha0)`, defined as follows:

```python
@partial(jit, static_argnames=['n', 'omega'])
def solve_diff(n, omega, kappa, alpha0):
    # Some functionality that uses kappa and alpha0

    H = hamiltonian(n, omega)

    # returns an expectation value
```

When I try to compute the gradient of this function using `jax.grad`:

```python
from jax import grad

n = 16
omega = 1.0
kappa = 0.1
alpha0 = 1.0

# Compute gradients with respect to omega, kappa, and alpha0
grad_population = grad(solve_diff, argnums=(1, 2, 3))
grads = grad_population(n, omega, kappa, alpha0)

print(f"Gradient w.r.t. omega: {grads[0]}")
print(f"Gradient w.r.t. kappa: {grads[1]}")
print(f"Gradient w.r.t. alpha0: {grads[2]}")
```

I get the following error:

```text
ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.ad.JVPTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<JVPTrace> with
  primal = 1.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding(). The error was:
TypeError: unhashable type: 'JVPTracer'
```

However, calling `solve_diff(16, 1.0, 0.1, 1.0)` directly works as expected.

If I remove `omega` from the static arguments of both `hamiltonian` and `solve_diff`, `grad` returns the gradients as expected.
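
A sketch of that working configuration, keeping only `n` static. The ladder-operator helpers and the body of `solve_diff` are hypothetical placeholders (the question elides the real body); the point is only that `omega` now stays a traced, differentiable value through both nested jitted functions:

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, jit


def annhilate(n: int):
    # Hypothetical truncated annihilation operator (stand-in for the question's helper).
    return jnp.diag(jnp.sqrt(jnp.arange(1, n, dtype=jnp.float32)), k=1)


def create(n: int):
    # Hypothetical creation operator: transpose of the real annihilation matrix.
    return annhilate(n).T


@partial(jit, static_argnames=['n'])  # only the matrix size is static
def hamiltonian(n: int, omega):
    return omega * create(n) @ annhilate(n)


@partial(jit, static_argnames=['n'])  # omega, kappa and alpha0 are all traced
def solve_diff(n, omega, kappa, alpha0):
    # Hypothetical placeholder for the elided body: any scalar built from
    # H, kappa and alpha0 is enough to illustrate grad through nested jit.
    H = hamiltonian(n, omega)
    return jnp.exp(-kappa) * alpha0**2 * jnp.trace(H)


grads = grad(solve_diff, argnums=(1, 2, 3))(4, 1.0, 0.1, 1.0)
print(grads)
```

With `n = 4` this placeholder equals `6 * omega * alpha0**2 * exp(-kappa)`, so the three printed gradients can be checked by hand. `n` can stay static because it fixes array shapes and is never differentiated.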

This confuses me, because I no longer know what qualifies as a static or a dynamic variable. By the definition that static variables do not change between function calls, both `n` and `omega` are constants and indeed should not change between calls.

Asked Apr 2 at 8:07 by yousef elbrolosy.

1 Answer

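The fundamental issue is that you cannot differentiate with respect to a static argument. `grad` replaces each argument it differentiates with a tracer (here a `JVPTracer`), while `jit` must hash its static arguments, and tracers cannot be hashed; that is exactly the error you observed.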
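
One way to see this: a static argument is consumed at trace time as an ordinary Python value and never becomes an input of the traced program, so there is nothing for `grad` to propagate a derivative through. The sketch below (the function `scaled_ramp` is hypothetical) uses `jax.make_jaxpr` with `static_argnums` to print the traced program:

```python
import jax
import jax.numpy as jnp


def scaled_ramp(n, omega):
    # `n` is used as an ordinary Python int (it fixes an array shape);
    # `omega` only enters through array arithmetic.
    return omega * jnp.arange(n, dtype=jnp.float32)


# Treat `n` (position 0) as static, as jit would with static_argnames=['n'].
closed = jax.make_jaxpr(scaled_ramp, static_argnums=0)(3, 1.0)
print(closed)
print(len(closed.jaxpr.invars))  # only `omega` is an input of the traced program
```

The printed jaxpr takes a single scalar input for `omega`; the value `n = 3` shows up only in the shape of the output.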

> This is confusing me, because I no longer know what qualifies as static or dynamic variables anymore, from the definition that static variables does not change between function calls

In JAX, the term "static" has nothing to do with whether a variable changes between function calls. Rather, a static argument is one that does not participate in tracing, the mechanism JAX uses to implement transformations such as `vmap`, `grad`, and `jit`. Its value is treated as a compile-time Python constant: it must be hashable, and passing a new value triggers a retrace and recompilation. When you differentiate with respect to a variable, it can no longer be static, because it participates in the autodiff transformation; trying to treat it as static later in the computation (for example, by passing it as a static argument to another jitted function) leads to an error.
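
The retracing behaviour can be observed with a Python side effect, which runs only while JAX traces the function. In this sketch (the function `ramp` and the counter `trace_count` are hypothetical), changing the dynamic `omega` reuses the compiled code, while changing the static `n` forces a new trace:

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


@partial(jax.jit, static_argnames=['n'])
def ramp(n, omega):
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX traces the function
    return omega * jnp.arange(n, dtype=jnp.float32)


ramp(3, 1.0)
ramp(3, 2.5)  # new dynamic value, same static n: the cached trace is reused
after_same_n = trace_count
print(after_same_n)
ramp(4, 1.0)  # new static value: triggers a fresh trace and compilation
print(trace_count)
```

So `omega` never changing between calls is not a reason to make it static; marking it static only buys compile-time constant folding, at the cost of making it non-differentiable.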

For a discussion of transformations, tracing, and related concepts, I'd start with JAX Key Concepts: transformations.