In an effort to optimize an existing function used in an optimization algorithm by applying `@jit`, I encountered some issues. When running the following function:
```python
import jax
import jax.numpy as jnp
from jax import grad, jacobian
from scipy.optimize import minimize
from scipy.interpolate import BSpline

jax.config.update("jax_enable_x64", True)

@jax.jit
def bspline(t_values, knots, coefficients, degree):
    """
    Generate a B-spline curve from given knots and coefficients.

    Parameters:
    - t_values: Array of parameter values where the spline is evaluated.
    - knots: Knot vector as a 1D numpy array.
    - coefficients: Control points as a 1D numpy array of shape (n,).
    - degree: Degree of the B-spline (e.g., 3 for cubic B-splines).

    Returns:
    - A numpy array of shape (num_points,) representing the B-spline curve.
    """
    def basis_function(i, k, t, knots):
        """Compute the basis function recursively."""
        if k == 0:
            return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
        else:
            denom1 = knots[i + k] - knots[i]
            denom2 = knots[i + k + 1] - knots[i + 1]
            term1 = (t - knots[i]) / denom1 * basis_function(i, k - 1, t, knots) if denom1 != 0 else 0
            term2 = (knots[i + k + 1] - t) / denom2 * basis_function(i + 1, k - 1, t, knots) if denom2 != 0 else 0
            return term1 + term2

    # Compute the B-spline curve points
    curve_points = jnp.zeros(len(t_values))
    for i in range(len(coefficients)):
        v = basis_function(i, degree, t_values, knots)
        curve_points = curve_points + v * coefficients[i]
    return curve_points
```
I get the following error:

```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
```
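The error comes from Python `if` statements (and `x if cond else y` expressions) applied to traced values: inside the jitted `bspline`, both `degree` (hence `k == 0`) and `knots` (hence `denom1 != 0`) are tracers, which have no concrete truth value. A minimal reproduction, using hypothetical toy functions rather than the code above, together with the usual `jnp.where` rewrite:

```python
import jax
import jax.numpy as jnp

@jax.jit
def inverse_or_zero_bad(x):
    # Under jit, `x != 0` is an abstract tracer, not a concrete bool,
    # so the Python `if` cannot decide which branch to trace.
    if x != 0:
        return 1.0 / x
    return 0.0

@jax.jit
def inverse_or_zero(x):
    # jnp.where selects elementwise at run time, so no Python bool is needed.
    # The inner where swaps zeros for 1.0 before dividing, so the untaken
    # branch stays finite and gradients do not turn into NaN.
    nonzero = x != 0
    return jnp.where(nonzero, 1.0 / jnp.where(nonzero, x, 1.0), 0.0)

try:
    inverse_or_zero_bad(2.0)
    raised = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError is a subclass of ConcretizationTypeError
    raised = True

print(raised)
print(inverse_or_zero(jnp.array([0.0, 2.0])))
print(jax.grad(inverse_or_zero)(0.0))
```

This covers the zero-denominator checks. It does not help with `if k == 0`, because that condition controls the recursion itself rather than selecting between two values.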
What I've Tried:
After consulting JAX's official documentation about this error, I modified `basis_function` to avoid direct boolean checks:
```python
def basis_function(i, k, t, knots):
    """Compute the basis function recursively."""
    return jnp.where(
        k == 0,
        jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0),
        jnp.where(
            (knots[i + k] - knots[i]) != 0,
            (t - knots[i]) / (knots[i + k] - knots[i]) * basis_function(i, k - 1, t, knots),
            0
        ) +
        jnp.where(
            (knots[i + k + 1] - knots[i + 1]) != 0,
            (knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) * basis_function(i + 1, k - 1, t, knots),
            0
        )
    )
```
However, now I encounter a `RecursionError`:

```
RecursionError: maximum recursion depth exceeded in comparison
```

This recursion issue seems to stem from applying `@jit`, as it was not present before.
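The `RecursionError` can be reproduced without `@jit` at all, which suggests that the `jnp.where` rewrite, rather than `jit`, is what removes the base case. A minimal sketch with hypothetical toy functions (`countdown_where`, `countdown_if`) that mimic the two versions of `basis_function`:

```python
import jax.numpy as jnp

def countdown_where(k):
    # jnp.where is an ordinary function: Python evaluates every argument,
    # including the recursive call, before jnp.where runs. There is no base
    # case, so k keeps decreasing (0, -1, -2, ...) until Python gives up.
    return jnp.where(k == 0, 1.0, countdown_where(k - 1))

def countdown_if(k):
    # A Python `if` on a concrete int evaluates only the branch it takes,
    # so the recursion stops at k == 0.
    if k == 0:
        return jnp.asarray(1.0)
    return countdown_if(k - 1)

try:
    countdown_where(3)  # plain Python call: no jax.jit involved
    hit_limit = False
except RecursionError:
    hit_limit = True

print(hit_limit)
print(countdown_if(3))
```

In the rewritten `basis_function`, the calls `basis_function(i, k - 1, ...)` and `basis_function(i + 1, k - 1, ...)` sit inside the arguments of `jnp.where`, so they run at every level regardless of the value of `k`.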
1 Answer
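Unfortunately, you cannot use recursive approaches in JAX where the recursion is based on a traced condition. A Python `if` on a traced value raises `TracerBoolConversionError`, and `jnp.where` is no substitute: it is an ordinary function, so Python evaluates both branch arguments, including the recursive call, before `jnp.where` runs, and the recursion never reaches its base case. You'll either have to write your recursion using Python control flow with static conditions, or you'll have to rewrite it using a non-recursive approach.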
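For the second option, here is a sketch of a non-recursive version. It assumes `degree` is still passed as a static argument, because it sets the number of loop iterations. The Cox-de Boor recursion is evaluated bottom-up: all degree-0 basis functions are built as one array, then the degree is raised one level at a time. The names `bspline_bottom_up` and `_safe_div` are hypothetical, and this is one possible formulation, not the only one:

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def _safe_div(num, den):
    # num / den where den != 0, else 0; the inner where keeps gradients finite.
    nonzero = den != 0
    return jnp.where(nonzero, num / jnp.where(nonzero, den, 1.0), 0.0)

@partial(jax.jit, static_argnames=["degree"])
def bspline_bottom_up(t_values, knots, coefficients, degree):
    """Evaluate sum_i coefficients[i] * B_{i,degree}(t) without recursion."""
    t = t_values[:, None]  # shape (m, 1), broadcasts against knot slices
    # Degree 0: B[:, i] = 1 on the half-open interval [knots[i], knots[i+1]).
    B = jnp.where((knots[:-1] <= t) & (t < knots[1:]), 1.0, 0.0)
    # Raise the degree one step at a time (Cox-de Boor). `degree` is static,
    # so this Python loop is unrolled into `degree` steps at trace time.
    for k in range(1, degree + 1):
        # left[:, i]  = (t - knots[i]) / (knots[i+k] - knots[i])
        left = _safe_div(t - knots[:-k - 1], knots[k:-1] - knots[:-k - 1])
        # right[:, i] = (knots[i+k+1] - t) / (knots[i+k+1] - knots[i+1])
        right = _safe_div(knots[k + 1:] - t, knots[k + 1:] - knots[1:-k])
        # Column i combines B_{i,k-1} and B_{i+1,k-1}; one column is lost per step.
        B = left * B[:, :-1] + right * B[:, 1:]
    return B[:, : coefficients.shape[0]] @ coefficients
```

For example, `bspline_bottom_up(ts, knots, coefs, degree=3)`. Each basis function is computed once per degree level, so the traced program grows linearly with `degree` instead of doubling at each level.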
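In your case, the first option seems doable so long as `degree` is static at the call site. In that case, you can mark `degree` as static when redefining your first function, so that `if k == 0` is evaluated on a Python int at trace time. Note that `knots` is still traced, so the `if denom1 != 0` and `if denom2 != 0` checks inside `basis_function` must also be rewritten with `jnp.where`: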
```python
from functools import partial

import jax

@partial(jax.jit, static_argnames=['degree'])
def bspline(t_values, knots, coefficients, degree):
    ...
```
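Putting both fixes together, here is a sketch of the full function. It mirrors the question's code, but `_safe_div` is a hypothetical helper: it replaces the `if denom != 0 else 0` expressions with the common double-`jnp.where` pattern, which also keeps gradients free of NaN when a denominator is zero (e.g. with repeated knots).

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def _safe_div(num, den):
    # num / den where den != 0, else 0; the inner where keeps the untaken
    # branch finite so gradients do not become NaN.
    nonzero = den != 0
    return jnp.where(nonzero, num / jnp.where(nonzero, den, 1.0), 0.0)

@partial(jax.jit, static_argnames=["degree"])
def bspline(t_values, knots, coefficients, degree):
    def basis_function(i, k, t, knots):
        # i and k are Python ints (k comes from the static `degree`), so this
        # `if` is resolved at trace time and the recursion terminates.
        if k == 0:
            return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
        # knots is traced, so the zero-denominator guards must use jnp.where.
        term1 = _safe_div(t - knots[i], knots[i + k] - knots[i]) \
            * basis_function(i, k - 1, t, knots)
        term2 = _safe_div(knots[i + k + 1] - t, knots[i + k + 1] - knots[i + 1]) \
            * basis_function(i + 1, k - 1, t, knots)
        return term1 + term2

    # len() of a tracer is its static leading dimension, so these loops are fine.
    curve_points = jnp.zeros(len(t_values))
    for i in range(len(coefficients)):
        v = basis_function(i, degree, t_values, knots)
        curve_points = curve_points + v * coefficients[i]
    return curve_points
```

Calls then pass `degree` as a Python int, e.g. `bspline(ts, knots, coefs, degree=3)`; each distinct `degree` value triggers its own compilation.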
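Keep in mind, though, that JAX tracing will unroll all such recursion, so this may end up generating a long program that leads to long compile times. Because each `basis_function` call at degree `k` makes two calls at degree `k - 1`, the traced program grows roughly as `len(coefficients) * 2**degree`.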
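To see this growth concretely, the sketch below counts the equations in the traced program with `jax.make_jaxpr` for increasing degree. `bspline_unrolled` and `count_ops` are hypothetical helpers that keep the question's recursive structure but omit the zero-denominator guards for brevity (the uniform knots used here keep every denominator nonzero).

```python
import jax
import jax.numpy as jnp

def bspline_unrolled(t, knots, coefficients, degree):
    # Same recursion shape as the question's basis_function: two calls per level.
    def basis(i, k):
        if k == 0:
            return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
        return ((t - knots[i]) / (knots[i + k] - knots[i]) * basis(i, k - 1)
                + (knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) * basis(i + 1, k - 1))
    return sum(basis(i, degree) * coefficients[i] for i in range(len(coefficients)))

def count_ops(degree, n_coef=6):
    # Uniform knots: len(knots) = n_coef + degree + 1, no zero denominators.
    knots = jnp.linspace(0.0, 1.0, n_coef + degree + 1)
    t = jnp.linspace(0.0, 0.99, 10)
    coefs = jnp.ones(n_coef)
    closed = jax.make_jaxpr(bspline_unrolled, static_argnums=(3,))(t, knots, coefs, degree)
    return len(closed.jaxpr.eqns)  # number of primitive operations traced

for d in range(1, 6):
    print(d, count_ops(d))
```

The count should roughly double with each extra degree, which is what drives the compile time. For the low degrees typical of splines (e.g. cubic) this may be acceptable; evaluating the recursion bottom-up, one degree at a time, keeps the program size linear in `degree`.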