python - TracerBoolConversion error while attempting to use @jit on functions - Stack Overflow

While trying to speed up an existing function in an optimization algorithm by applying @jit, I ran into some issues. When I run the following function:

import jax
import jax.numpy as jnp
from jax import grad, jacobian
from scipy.optimize import minimize
from scipy.interpolate import BSpline

jax.config.update("jax_enable_x64", True)

@jax.jit
def bspline(t_values, knots, coefficients, degree):
    """
    Generate a B-spline curve from given knots and coefficients.

    Parameters:
    - t_values: Array of parameter values where the spline is evaluated.
    - knots: Knot vector as a 1D numpy array.
    - coefficients: Control points as a 1D numpy array of shape (n,).
    - degree: Degree of the B-spline (e.g., 3 for cubic B-splines).

    Returns:
    - A numpy array of shape (num_points,) representing the B-spline curve.
    """
    def basis_function(i, k, t, knots):
        """Compute the basis function recursively."""
        if k == 0:
            return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
        else:
            denom1 = knots[i + k] - knots[i]
            denom2 = knots[i + k + 1] - knots[i + 1]

            term1 = (t - knots[i]) / denom1 * basis_function(i, k - 1, t, knots) if denom1 != 0 else 0
            term2 = (knots[i + k + 1] - t) / denom2 * basis_function(i + 1, k - 1, t, knots) if denom2 != 0 else 0

            return term1 + term2

    # Compute the B-spline curve points
    curve_points = jnp.zeros(len(t_values))
    for i in range(len(coefficients)):
        v = basis_function(i, degree, t_values, knots)
        curve_points = curve_points + v * coefficients[i]

    return curve_points

I get the following error:

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
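For context, a minimal reproduction of this class of error might look like the sketch below. The function names are hypothetical, not from the post. Under jax.jit, a comparison on a traced argument produces a tracer, and a Python `if` has to convert it to a concrete bool, which is what raises. jnp.where avoids the conversion by selecting values elementwise. The check catches jax.errors.ConcretizationTypeError, of which TracerBoolConversionError is a subclass.

```python
import jax
import jax.numpy as jnp


def clip_with_if(x):
    # Python `if` forces bool(x > 0); under jit, `x > 0` is a tracer, so this raises.
    if x > 0:
        return x
    return jnp.zeros_like(x)


def clip_with_where(x):
    # jnp.where keeps the condition as an array and selects elementwise.
    return jnp.where(x > 0, x, 0.0)


def raises_under_jit(fn, *args):
    """Return True if tracing `fn` under jax.jit hits a concretization error."""
    try:
        jax.jit(fn)(*args)
    except jax.errors.ConcretizationTypeError:  # TracerBoolConversionError subclasses this
        return True
    return False


print(clip_with_if(2.0))                    # plain Python call: the `if` is fine
print(raises_under_jit(clip_with_if, 2.0))  # True
print(jax.jit(clip_with_where)(jnp.array([-1.0, 2.0])))
```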

What I've Tried:

After consulting JAX's official documentation about this error (available here), I modified the basis_function to avoid direct boolean checks:

def basis_function(i, k, t, knots):
    """Compute the basis function recursively."""
    return jnp.where(
        k == 0,
        jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0),
        jnp.where(
            (knots[i + k] - knots[i]) != 0,
            (t - knots[i]) / (knots[i + k] - knots[i]) * basis_function(i, k - 1, t, knots),
            0
        ) +
        jnp.where(
            (knots[i + k + 1] - knots[i + 1]) != 0,
            (knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) * basis_function(i + 1, k - 1, t, knots),
            0
        )
    )

However, now I encounter a RecursionError:

RecursionError: maximum recursion depth exceeded in comparison

This recursion issue seems to stem from applying @jit, as it was not present before.
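A note on the RecursionError in the modified version: Python evaluates every argument of jnp.where before the call, so basis_function(i, k - 1, ...) is invoked even when k == 0, and k keeps decreasing (-1, -2, ...) with no base case. That modified version would therefore recurse without bound with or without @jit. The sketch below reproduces the pattern with a hypothetical `count_down` function and no tracing at all.

```python
import jax.numpy as jnp

visited = []  # records every k the function is called with


def count_down(k):
    visited.append(k)
    # Both branch values are evaluated *before* jnp.where is called, so the
    # recursive call runs even when k == 0: jnp.where is not a base case.
    return jnp.where(k == 0, 1.0, count_down(k - 1))


try:
    count_down(3)
    hit_recursion_error = False
except RecursionError:
    hit_recursion_error = True

print(hit_recursion_error)  # True
print(min(visited) < 0)     # True: k went negative
```

In the original if/else version, `if k == 0` is a Python-level base case, which is why keeping k a plain Python int is what makes the recursion terminate.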

asked Feb 6 at 15:23 by KNIGHT

1 Answer

Unfortunately, you cannot use recursion in JAX when the condition that ends the recursion depends on a traced value. You'll either have to write your recursion using Python control flow on static conditions, or rewrite it using a non-recursive approach.
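To illustrate the difference between a static and a traced stopping condition, here is a small hypothetical example (`nested_sqrt` is not from the post). When n is marked static via static_argnames, it stays a Python int during tracing, `if n == 0` is resolved at trace time, and the recursion is unrolled. When n is traced, the same `if` raises.

```python
import jax
import jax.numpy as jnp


def nested_sqrt(x, n):
    """Compute sqrt(1 + sqrt(1 + ... + x)) with n nested square roots."""
    if n == 0:  # Python-level base case: needs a concrete n
        return x
    return jnp.sqrt(1.0 + nested_sqrt(x, n - 1))


# n static: a Python int at trace time, so the recursion unrolls into n sqrt ops.
nested_sqrt_static = jax.jit(nested_sqrt, static_argnames=["n"])

# n traced: `n == 0` is a tracer, and the `if` cannot turn it into a bool.
nested_sqrt_traced = jax.jit(nested_sqrt)

print(nested_sqrt_static(1.0, n=1))  # approximately sqrt(2)

try:
    nested_sqrt_traced(1.0, 2)
    traced_failed = False
except jax.errors.ConcretizationTypeError:
    traced_failed = True
```

Each distinct static value of n triggers a new trace and compile, which is the price of this approach.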

In your case, the first option seems doable so long as degree is static at the call-site. In that case, you could fix your issue by redefining your first function this way:

from functools import partial

@partial(jax.jit, static_argnames=['degree'])
def bspline(t_values, knots, coefficients, degree):
  ...
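A fuller sketch of how the whole function might look with degree static, offered as an illustration under stated assumptions rather than as the answerer's code. Marking degree static turns `if k == 0` into trace-time Python control flow, but knots is still a traced array, so the `if denom1 != 0 else 0` checks from the question would still raise TracerBoolConversionError; the sketch replaces them with jnp.where. The helper `_safe_ratio` is a hypothetical name introduced here; its inner jnp.where swaps zero denominators for 1.0 so that gradients (the question imports grad and jacobian) do not pick up NaN from the unselected branch.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _safe_ratio(num, den):
    """num / den where den != 0, else 0 (the usual 0/0 := 0 B-spline convention).

    The inner jnp.where replaces zero denominators with 1.0 so that neither
    the value nor the gradient of the unselected branch contains inf or NaN.
    """
    nonzero = den != 0
    return jnp.where(nonzero, num / jnp.where(nonzero, den, 1.0), 0.0)


@partial(jax.jit, static_argnames=["degree"])
def bspline(t_values, knots, coefficients, degree):
    def basis_function(i, k, t):
        # i and k are Python ints (degree is static), so this `if` is resolved
        # during tracing and the recursion unrolls into straight-line code.
        if k == 0:
            return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
        # knots is traced, so the zero-denominator checks must use jnp.where.
        term1 = _safe_ratio(t - knots[i], knots[i + k] - knots[i]) * basis_function(i, k - 1, t)
        term2 = _safe_ratio(knots[i + k + 1] - t, knots[i + k + 1] - knots[i + 1]) * basis_function(i + 1, k - 1, t)
        return term1 + term2

    # len() of a traced array is its static shape, so this loop is unrolled too.
    curve_points = jnp.zeros(len(t_values))
    for i in range(len(coefficients)):
        curve_points = curve_points + basis_function(i, degree, t_values) * coefficients[i]
    return curve_points


t = jnp.linspace(0.0, 0.99, 5)
knots = jnp.array([0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0, 1.0])  # clamped cubic knots
print(bspline(t, knots, jnp.ones(5), degree=3))  # approximately 1.0 everywhere on [0, 1)
```

Because degree is static, it must be passed as a hashable Python int (for example `degree=3`), and each new value triggers a fresh trace and compile.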

Keep in mind, though, that JAX tracing unrolls all such recursion, so this may generate a long program and lead to long compile times.
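The answer also mentions rewriting without recursion. One option, swapped in here and not taken from the post, is the bottom-up form of the Cox-de Boor recursion: compute all degree-0 basis functions at once, then build each degree from the previous one with a static Python loop over k. The traced program then grows linearly with degree, whereas the recursive version makes 2**degree leaf calls per coefficient. The sketch assumes len(coefficients) == len(knots) - degree - 1 and keeps the question's half-open convention, so t equal to the last knot evaluates to 0. `_safe_ratio` and `bspline_iterative` are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _safe_ratio(num, den):
    # num / den where den != 0, else 0; the inner where keeps gradients NaN-free.
    nonzero = den != 0
    return jnp.where(nonzero, num / jnp.where(nonzero, den, 1.0), 0.0)


@partial(jax.jit, static_argnames=["degree"])
def bspline_iterative(t_values, knots, coefficients, degree):
    # Assumes len(coefficients) == len(knots) - degree - 1.
    t = t_values[:, None]  # shape (m, 1) so it broadcasts against knot slices
    # Degree 0: indicator of each half-open knot span, shape (m, len(knots) - 1).
    basis = jnp.where((knots[:-1] <= t) & (t < knots[1:]), 1.0, 0.0)
    for k in range(1, degree + 1):  # static loop: `degree` vectorized steps
        # Slices line up entry i with knots[i], knots[i+k], knots[i+1], knots[i+k+1].
        left = _safe_ratio(t - knots[:-k - 1], knots[k:-1] - knots[:-k - 1])
        right = _safe_ratio(knots[k + 1:] - t, knots[k + 1:] - knots[1:-k])
        # Cox-de Boor step: B[i, k] = left[i] * B[i, k-1] + right[i] * B[i+1, k-1]
        basis = left * basis[:, :-1] + right * basis[:, 1:]
    return basis @ coefficients  # shape (m,)


t = jnp.array([1.0, 1.5, 2.0])
# Linear B-spline on knots 0..3 with control values 2 and 5: gives 2.0, 3.5, 5.0.
print(bspline_iterative(t, jnp.array([0.0, 1.0, 2.0, 3.0]), jnp.array([2.0, 5.0]), degree=1))
```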
