Let's say I have a function `foo(x, a, b)` and I want to find a specific one of its (potentially multiple) roots `x0`, i.e. a value `x0` such that `foo(x0, a, b) == 0`. I know that for `(a, b) == (0, 0)` the root I want is `x0 == 0` and that the function changes continuously with `a` and `b`, so I can "follow" the root from `(0, 0)` to the desired `(a, b)`.
Here's an example function.
def foo(x, a, b):
    """Example residual ``(1 + a) * sin(a + b - x) - x`` whose root is followed."""
    amplitude = 1 + a
    return amplitude * np.sin(a + b - x) - x
For `(a, b) == (0, 0)` I want the root at `0`, for `(2, 0)` I want the one near `1.5`, and for `(2, 1)` I want the one near `2.2`.
Now, this problem seems common enough that there might be a prepared, fast solver in `scipy` or another standard package (or tools to construct one easily and efficiently). However, I don't know what terms to search for to find it (or to verify that it doesn't exist). Is there a ready-made tool for this? What is this kind of problem called?
Clarifications:
- Depending on `(a, b)`, the "correct" root may disappear (e.g. for `(1, 3)` in the example). When this happens, returning `nan` is the preferred behavior, though this is not super important.
- By "fast" I mostly mean that many parameter sets `(a, b)` can be solved quickly, not just a single one. I will go on to calculate the root for a lot of different parameters, e.g. for plotting and integrating over them.
Here's a quickly put together reference implementation that does pretty much what I want for the example above (and creates the plot). It's not very fast, of course, which somewhat limits me in my actual application.
import functools
import numpy as np
from scipy.optimize import root_scalar
from matplotlib import pyplot as plt
def foo(x, a, b):
    """Residual ``(1 + a) * sin(a + b - x) - x``; its roots in ``x`` are sought."""
    amplitude = 1 + a
    return amplitude * np.sin(a + b - x) - x
# Plot foo(x, a, b) over one period for the three reference parameter sets.
fig, ax = plt.subplots()
ax.grid()
ax.set(xlabel="x", ylabel="foo")
x = np.linspace(-np.pi, np.pi, 201)
for a, b in [(0, 0), (2, 0), (2, 1)]:
    ax.plot(x, foo(x, a, b), label=f"(a, b) = ({a}, {b})")
ax.legend()
plt.show()
# Semi-bodged solver for reference:
def _dfoo(x, a, b):
    """Return the derivative of ``foo`` with respect to ``x``."""
    cos_term = np.cos(a + b - x)
    return -(1 + a) * cos_term - 1
def _solve_fooroot(guess, a, b):
    """Polish ``guess`` into a nearby root of ``foo(., a, b)``.

    Starting at ``guess``, walk in the direction of the Newton step
    ``-foo/dfoo`` in increments of a tenth of that step. The step is first
    enlarged to at least ``1e-2 * maxlim`` in magnitude, and the walk covers
    at most ten such steps. It stops at the first sign change of ``foo``; the
    root inside that bracket is then found with ``root_scalar``.

    Returns ``guess`` itself if it is already an exact root. Returns
    ``np.nan`` in any of these cases:

    * ``guess`` is NaN;
    * the Newton step is undefined (zero or non-finite derivative, or a
      non-finite residual);
    * no sign change is found within the search range.

    Each of these means the followed root appears to have been lost.
    """
    if np.isnan(guess):
        return np.nan
    # Scale for the minimum search step; slightly larger than 1 + a to allow
    # for numerical imprecision.
    maxlim = 1.1 * (1 + a)
    y0 = foo(guess, a, b)
    if y0 == 0:
        return guess
    dy0 = _dfoo(guess, a, b)
    if dy0 == 0 or not np.isfinite(dy0):
        # No usable Newton direction (e.g. near a point where the followed
        # root vanishes). Dividing would give inf/nan, and np.arange below
        # would then raise instead of reporting a lost root.
        return np.nan
    estimate = -y0 / dy0
    if not np.isfinite(estimate):
        return np.nan
    # Too small estimates are no good: enforce a minimum magnitude while
    # keeping the direction. np.copysign also yields a non-zero step if the
    # estimate underflowed to 0, where np.sign(0) would give a zero step.
    if np.abs(estimate) < 1e-2 * maxlim:
        estimate = np.copysign(1e-2 * maxlim, estimate)
    for lim in np.arange(guess, guess + 10 * estimate, 1e-1 * estimate):
        if np.sign(foo(lim, a, b)) != np.sign(y0):
            bracket = sorted([guess, lim])
            break
    else:
        # No sign change: the root we were following is gone.
        return np.nan
    sol = root_scalar(foo, args=(a, b), bracket=bracket)
    return sol.root
@functools.cache
def _fooroot(an, astep, bn, bstep):
    """Return the followed root of ``foo`` at ``(a, b) = (an * astep, bn * bstep)``.

    The root is continued from the known root ``0`` at ``(0, 0)`` through
    grid points. At ``an == 0`` the predecessor of a point is
    ``(0, bn - 1)``; otherwise it is ``(an - 1, bn)``. In other words, the
    path runs along ``b`` at ``a = 0`` and then along ``a``. Each point is
    solved by ``_solve_fooroot``, seeded with its predecessor's root.
    ``functools.cache`` memoises every visited point.

    ``an`` and ``bn`` must be >= 0. ``fooroot`` ensures this by giving each
    step the sign of its parameter. For a negative index the predecessor
    chain never reaches 0, so the recursion would not terminate.

    Returns ``0`` at the origin. Otherwise it returns the result of
    ``_solve_fooroot``, which is NaN when no root could be bracketed.
    """
    if an == 0:
        if bn == 0:
            # Known starting root at (a, b) == (0, 0).
            return 0
        # On the a == 0 axis: step back along b.
        guessan, guessbn = an, bn - 1
    else:
        # Step back along a at fixed b.
        guessan, guessbn = an - 1, bn
    # Avoid reaching maximum recursion depth.
    # Solve every 100th point on the path first, in increasing order, so the
    # recursive call below reaches an already-cached point within about 100
    # levels of recursion.
    for thisbn in range(0, guessbn, 100):
        _fooroot(0, astep, thisbn, bstep)
    for thisan in range(0, guessan, 100):
        _fooroot(thisan, astep, guessbn, bstep)
    guess = _fooroot(guessan, astep, guessbn, bstep)
    return _solve_fooroot(guess, an * astep, bn * bstep)
@functools.partial(np.vectorize, otypes=[float])
def fooroot(a, b):
    """Return the root of ``foo(., a, b)`` continued from ``x0 = 0`` at ``(0, 0)``.

    Vectorized over ``a`` and ``b``. ``_fooroot`` continues the root over a
    grid of spacing 0.01; the result is then polished at ``(a, b)`` itself.
    The result is NaN where the root was lost.

    ``otypes=[float]`` pins the output dtype. Without it, ``np.vectorize``
    infers the dtype from the first element. ``fooroot(0, 0)`` yields the
    integer ``0``, so an input that starts at ``(0, 0)`` would otherwise
    truncate every other root to an integer.
    """
    # Give each step the sign of its parameter so that the grid indices
    # passed to _fooroot are non-negative (it counts them down to 0).
    astep = (-1 if a < 0 else 1) * 1e-2
    bstep = (-1 if b < 0 else 1) * 1e-2
    # Root at the grid point nearest (a, b) towards the origin; int()
    # truncates towards zero.
    guess = _fooroot(int(a / astep), astep, int(b / bstep), bstep)
    return _solve_fooroot(guess, a, b)
# Follow the root to a few reference parameter sets.
for a, b in [(0, 0), (2, 0), (2, 1)]:
    print(fooroot(a, b))

# Root as a function of b for a few fixed values of a.
fig, ax = plt.subplots()
ax.grid()
ax.set(xlabel="b", ylabel="fooroot(a, b)")
b = np.linspace(-3, 3, 201)
for a in (0, 0.2, 0.5):
    ax.plot(b, fooroot(a, b), label=f"a = {a}")
ax.legend()
plt.show()

# Root over the whole (a, b) plane.
fig, ax = plt.subplots()
ax.grid()
ax.set(xlabel="a", ylabel="b")
a = np.linspace(-1, 1, 201)
b = np.linspace(-3.5, 3.5, 201)
aa, bb = np.meshgrid(a, b)
pcm = ax.pcolormesh(aa, bb, fooroot(aa, bb))
fig.colorbar(pcm, label="fooroot(a, b)")
plt.show()
Let's say I have a function `foo(x, a, b)` and I want to find a specific one of its (potentially multiple) roots `x0`, i.e. a value `x0` such that `foo(x0, a, b) == 0`. I know that for `(a, b) == (0, 0)` the root I want is `x0 == 0` and that the function changes continuously with `a` and `b`, so I can "follow" the root from `(0, 0)` to the desired `(a, b)`.
Here's an example function.
def foo(x, a, b):
    """Example residual ``(1 + a) * sin(a + b - x) - x`` whose root is followed."""
    amplitude = 1 + a
    return amplitude * np.sin(a + b - x) - x
For `(a, b) == (0, 0)` I want the root at `0`, for `(2, 0)` I want the one near `1.5`, and for `(2, 1)` I want the one near `2.2`.
Now, this problem seems common enough that there might be a prepared, fast solver in `scipy` or another standard package (or tools to construct one easily and efficiently). However, I don't know what terms to search for to find it (or to verify that it doesn't exist). Is there a ready-made tool for this? What is this kind of problem called?
Clarifications:
- Depending on `(a, b)`, the "correct" root may disappear (e.g. for `(1, 3)` in the example). When this happens, returning `nan` is the preferred behavior, though this is not super important.
- By "fast" I mostly mean that many parameter sets `(a, b)` can be solved quickly, not just a single one. I will go on to calculate the root for a lot of different parameters, e.g. for plotting and integrating over them.
Here's a quickly put together reference implementation that does pretty much what I want for the example above (and creates the plot). It's not very fast, of course, which somewhat limits me in my actual application.
import functools
import numpy as np
from scipy.optimize import root_scalar
from matplotlib import pyplot as plt
def foo(x, a, b):
    """Residual ``(1 + a) * sin(a + b - x) - x``; its roots in ``x`` are sought."""
    amplitude = 1 + a
    return amplitude * np.sin(a + b - x) - x
# Plot foo(x, a, b) over one period for the three reference parameter sets.
fig, ax = plt.subplots()
ax.grid()
ax.set(xlabel="x", ylabel="foo")
x = np.linspace(-np.pi, np.pi, 201)
for a, b in [(0, 0), (2, 0), (2, 1)]:
    ax.plot(x, foo(x, a, b), label=f"(a, b) = ({a}, {b})")
ax.legend()
plt.show()
# Semi-bodged solver for reference:
def _dfoo(x, a, b):
    """Return the derivative of ``foo`` with respect to ``x``."""
    cos_term = np.cos(a + b - x)
    return -(1 + a) * cos_term - 1
def _solve_fooroot(guess, a, b):
    """Polish ``guess`` into a nearby root of ``foo(., a, b)``.

    Starting at ``guess``, walk in the direction of the Newton step
    ``-foo/dfoo`` in increments of a tenth of that step. The step is first
    enlarged to at least ``1e-2 * maxlim`` in magnitude, and the walk covers
    at most ten such steps. It stops at the first sign change of ``foo``; the
    root inside that bracket is then found with ``root_scalar``.

    Returns ``guess`` itself if it is already an exact root. Returns
    ``np.nan`` in any of these cases:

    * ``guess`` is NaN;
    * the Newton step is undefined (zero or non-finite derivative, or a
      non-finite residual);
    * no sign change is found within the search range.

    Each of these means the followed root appears to have been lost.
    """
    if np.isnan(guess):
        return np.nan
    # Scale for the minimum search step; slightly larger than 1 + a to allow
    # for numerical imprecision.
    maxlim = 1.1 * (1 + a)
    y0 = foo(guess, a, b)
    if y0 == 0:
        return guess
    dy0 = _dfoo(guess, a, b)
    if dy0 == 0 or not np.isfinite(dy0):
        # No usable Newton direction (e.g. near a point where the followed
        # root vanishes). Dividing would give inf/nan, and np.arange below
        # would then raise instead of reporting a lost root.
        return np.nan
    estimate = -y0 / dy0
    if not np.isfinite(estimate):
        return np.nan
    # Too small estimates are no good: enforce a minimum magnitude while
    # keeping the direction. np.copysign also yields a non-zero step if the
    # estimate underflowed to 0, where np.sign(0) would give a zero step.
    if np.abs(estimate) < 1e-2 * maxlim:
        estimate = np.copysign(1e-2 * maxlim, estimate)
    for lim in np.arange(guess, guess + 10 * estimate, 1e-1 * estimate):
        if np.sign(foo(lim, a, b)) != np.sign(y0):
            bracket = sorted([guess, lim])
            break
    else:
        # No sign change: the root we were following is gone.
        return np.nan
    sol = root_scalar(foo, args=(a, b), bracket=bracket)
    return sol.root
@functools.cache
def _fooroot(an, astep, bn, bstep):
    """Return the followed root of ``foo`` at ``(a, b) = (an * astep, bn * bstep)``.

    The root is continued from the known root ``0`` at ``(0, 0)`` through
    grid points. At ``an == 0`` the predecessor of a point is
    ``(0, bn - 1)``; otherwise it is ``(an - 1, bn)``. In other words, the
    path runs along ``b`` at ``a = 0`` and then along ``a``. Each point is
    solved by ``_solve_fooroot``, seeded with its predecessor's root.
    ``functools.cache`` memoises every visited point.

    ``an`` and ``bn`` must be >= 0. ``fooroot`` ensures this by giving each
    step the sign of its parameter. For a negative index the predecessor
    chain never reaches 0, so the recursion would not terminate.

    Returns ``0`` at the origin. Otherwise it returns the result of
    ``_solve_fooroot``, which is NaN when no root could be bracketed.
    """
    if an == 0:
        if bn == 0:
            # Known starting root at (a, b) == (0, 0).
            return 0
        # On the a == 0 axis: step back along b.
        guessan, guessbn = an, bn - 1
    else:
        # Step back along a at fixed b.
        guessan, guessbn = an - 1, bn
    # Avoid reaching maximum recursion depth.
    # Solve every 100th point on the path first, in increasing order, so the
    # recursive call below reaches an already-cached point within about 100
    # levels of recursion.
    for thisbn in range(0, guessbn, 100):
        _fooroot(0, astep, thisbn, bstep)
    for thisan in range(0, guessan, 100):
        _fooroot(thisan, astep, guessbn, bstep)
    guess = _fooroot(guessan, astep, guessbn, bstep)
    return _solve_fooroot(guess, an * astep, bn * bstep)
@functools.partial(np.vectorize, otypes=[float])
def fooroot(a, b):
    """Return the root of ``foo(., a, b)`` continued from ``x0 = 0`` at ``(0, 0)``.

    Vectorized over ``a`` and ``b``. ``_fooroot`` continues the root over a
    grid of spacing 0.01; the result is then polished at ``(a, b)`` itself.
    The result is NaN where the root was lost.

    ``otypes=[float]`` pins the output dtype. Without it, ``np.vectorize``
    infers the dtype from the first element. ``fooroot(0, 0)`` yields the
    integer ``0``, so an input that starts at ``(0, 0)`` would otherwise
    truncate every other root to an integer.
    """
    # Give each step the sign of its parameter so that the grid indices
    # passed to _fooroot are non-negative (it counts them down to 0).
    astep = (-1 if a < 0 else 1) * 1e-2
    bstep = (-1 if b < 0 else 1) * 1e-2
    # Root at the grid point nearest (a, b) towards the origin; int()
    # truncates towards zero.
    guess = _fooroot(int(a / astep), astep, int(b / bstep), bstep)
    return _solve_fooroot(guess, a, b)
# Follow the root to a few reference parameter sets.
for a, b in [(0, 0), (2, 0), (2, 1)]:
    print(fooroot(a, b))

# Root as a function of b for a few fixed values of a.
fig, ax = plt.subplots()
ax.grid()
ax.set(xlabel="b", ylabel="fooroot(a, b)")
b = np.linspace(-3, 3, 201)
for a in (0, 0.2, 0.5):
    ax.plot(b, fooroot(a, b), label=f"a = {a}")
ax.legend()
plt.show()

# Root over the whole (a, b) plane.
fig, ax = plt.subplots()
ax.grid()
ax.set(xlabel="a", ylabel="b")
a = np.linspace(-1, 1, 201)
b = np.linspace(-3.5, 3.5, 201)
aa, bb = np.meshgrid(a, b)
pcm = ax.pcolormesh(aa, bb, fooroot(aa, bb))
fig.colorbar(pcm, label="fooroot(a, b)")
plt.show()
Share
Improve this question
edited Apr 1 at 9:13
schtandard
asked Mar 31 at 10:14
schtandardschtandard
5315 silver badges24 bronze badges
5
|
2 Answers
Reset to default
0
Get rid of your `@cache` and `@vectorize`; neither is likely to help you with the following, and they're just noise. (If they're needed for outer code, you haven't shown that outer code, so the point stands.)
Do keep using Scipy's root-finding iteratively, but beyond that your procedure should look pretty different. I propose:
Get your initial estimate x0, a0, b0. Then in a loop:
- Infer analytically (by differentiating f) the paraboloid through the current point whose first and second derivatives with respect to a and b match those of the root x(a, b), i.e. a second-order Taylor expansion of the root.
- Increment a and b by a fixed step size. If the function is smooth and the parabolic step estimation works well, this step may be fairly large; but if you're writing a generic routine that can take any function, the step size must be a parameter.
- Call `root_scalar` with your new estimate and the new a and b, passing the analytic `fprime` and `fprime2`, probably using Halley's method, and with relaxed tolerances that are parametric and appropriate to the function.
Of course that's the ideal case. In practice, SciPy's vectorised root finders cannot easily use second-order derivatives, so Halley's method is unavailable; but even first-order (Jacobian) steps work well.
The following demonstration shows a fully-vectorised path traversal, with two independent start points and stop points in a, b space. This can be arbitrarily extended to as many consecutive paths as you want.
import logging
import typing
import numpy as np
from numpy._typing import ArrayLike
from scipy.optimize import root, check_grad
class Trivariate(typing.Protocol):
    """Structural type for callables ``f(x, a, b) -> ndarray``.

    Used to annotate the residual function and its partial derivatives.
    """

    def __call__(self, x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        ...
def foo(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Residual ``(1 + a) sin(a + b - x) - x`` whose root ``x(a, b)`` is followed."""
    phase = a + b - x
    return (1 + a)*np.sin(phase) - x
def dfoo_dx(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Partial derivative of ``foo`` with respect to ``x``."""
    phase = a + b - x
    return -(1 + a)*np.cos(phase) - 1
def dfoo_da(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Partial derivative of ``foo`` with respect to ``a`` (product rule on ``(1 + a) sin``)."""
    phase = a + b - x
    return np.sin(phase) + (1 + a)*np.cos(phase)
def dfoo_db(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Partial derivative of ``foo`` with respect to ``b``."""
    return np.cos(a + b - x)*(1 + a)
def follow_root(
    fun: Trivariate, dfdx: Trivariate, dfda: Trivariate, dfdb: Trivariate,
    a0: ArrayLike, a1: ArrayLike,
    b0: ArrayLike, b1: ArrayLike,
    x0: ArrayLike,
    steps: int = 10,
    method: str = 'hybr',
    follow_tol: float = 1e-2, follow_reltol: float = 1e-2,
    polish_tol: float = 1e-12, polish_reltol: float = 1e-12,
) -> np.ndarray:
    """Continue roots of ``fun(x, a, b) = 0`` along straight lines in (a, b).

    Parameters
    ----------
    a0, b0, a1, b1
        Start and end parameters. Scalars or arrays; each element is one
        independent path.
    x0
        Known roots at ``(a0, b0)``.
    steps
        Number of evenly spaced samples on each segment, including both
        ends. Must be >= 2.

    Algorithm
    ---------
    At each step the next root is predicted to first order by implicit
    differentiation, ``dx = -(f_a da + f_b db) / f_x``, and then corrected
    with ``scipy.optimize.root``. Intermediate steps use the ``follow_*``
    tolerances; the final step uses the ``polish_*`` ones.

    ``root`` uses ``tol`` only as a default for ``options['xtol']`` with
    methods 'hybr'/'lm'. Since ``xtol`` is always passed here, for those
    methods the ``*_reltol`` values control termination and ``*_tol`` has
    no effect.

    Returns
    -------
    The roots at ``(a1, b1)``.

    Raises
    ------
    ValueError
        If ``steps < 2`` or a root solve fails.
    """
    if steps < 2:
        # The step sizes below need at least two samples.
        raise ValueError(f'steps must be at least 2, got {steps}')

    def dfdx_sparse(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Each path's residual depends only on its own x, so the Jacobian is
        # diagonal (stored densely here).
        return np.diag(dfdx(x, a, b))

    x_est = np.asarray(x0)
    ab0 = np.array((a0, b0))
    ab1 = np.array((a1, b1))
    # Shape (steps, 2, *shape of a0): for each step, the (a, b) pair.
    ab = np.linspace(start=ab0, stop=ab1, num=steps)
    da = ab[1, 0] - ab[0, 0]
    db = ab[1, 1] - ab[0, 1]
    last_index = len(ab) - 2
    for i, ((ai, bi), (ai1, bi1)) in enumerate(zip(ab[:-1], ab[1:])):
        dfdxi = dfdx(x_est, ai, bi)
        # These are -dx/da and -dx/db: implicit differentiation of
        # fun(x(a, b), a, b) = 0 gives dx/da = -f_a / f_x.
        dxda = dfda(x_est, ai, bi)/dfdxi
        dxdb = dfdb(x_est, ai, bi)/dfdxi
        # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
        step = -dxda*da - dxdb*db
        last = i == last_index
        result = root(
            fun=fun, args=(ai1, bi1), jac=dfdx_sparse, x0=x_est + step, method=method,
            tol=polish_tol if last else follow_tol,
            options={
                'xtol': polish_reltol if last else follow_reltol,
                'col_deriv': True,
            },
        )
        if not result.success:
            raise ValueError('Root finding failed: ' + result.message)
        logging.debug('#%d x%d %s +%s ~ %s = %s', i, result.nfev, x_est, step, x_est + step, result.x)
        x_est = result.x
    return x_est
def main() -> None:
    """Check the analytic derivatives of foo, then follow two example roots."""
    # Don't do this in production!
    logging.getLogger().setLevel(logging.DEBUG)

    x_ref = np.array((0.1, 1.5))
    a_ref = np.array((0, 2))
    b_ref = np.array((0, 0))
    # Each entry: (function of the checked variable, its analytic Jacobian,
    # point at which to compare them).
    gradient_checks = (
        (lambda x: foo(x, a_ref, b_ref),
         lambda x: np.diag(dfoo_dx(x, a_ref, b_ref)),
         (0.1, 1.5)),  # x0
        (lambda a: foo(x_ref, a, b_ref),
         lambda a: np.diag(dfoo_da(x_ref, a, b_ref)),
         (0, 2)),  # a0
        (lambda b: foo(x_ref, a_ref, b),
         lambda b: np.diag(dfoo_db(x_ref, a_ref, b)),
         (0, 0)),  # b0
    )
    for func, grad, point in gradient_checks:
        err = check_grad(func, grad, point)
        assert err < 1e-7

    follow_root(
        fun=foo, dfdx=dfoo_dx, dfda=dfoo_da, dfdb=dfoo_db,
        a0=(0, 2), a1=(1.7, 2),
        b0=(0, 0), b1=(0, 1),
        x0=(0, 1.5),
    )


if __name__ == "__main__":
    main()
DEBUG:root:#0 x5 [0. 1.5] +[0.09444444 0.08052514] ~ [0.09444444 1.58052514] = [0.1025362 1.56306428]
DEBUG:root:#1 x4 [0.1025362 1.56306428] +[0.10987707 0.07990566] ~ [0.21241326 1.64296994] = [0.21851137 1.64274921]
DEBUG:root:#2 x4 [0.21851137 1.64274921] +[0.12155443 0.07945781] ~ [0.3400658 1.72220702] = [0.34478035 1.72196514]
DEBUG:root:#3 x4 [0.34478035 1.72196514] +[0.13061949 0.0789664 ] ~ [0.47539984 1.80093153] = [0.47912896 1.80066588]
DEBUG:root:#4 x4 [0.47912896 1.80066588] +[0.13781338 0.07842657] ~ [0.61694234 1.87909245] = [0.61994978 1.87880026]
DEBUG:root:#5 x4 [0.61994978 1.87880026] +[0.14363144 0.07783262] ~ [0.76358122 1.95663288] = [0.76604746 1.95631084]
DEBUG:root:#6 x4 [0.76604746 1.95631084] +[0.14841417 0.07717773] ~ [0.91446163 2.03348857] = [0.9165135 2.03313271]
DEBUG:root:#7 x4 [0.9165135 2.03313271] +[0.15240176 0.07645371] ~ [1.06891526 2.10958643] = [1.07064402 2.10919196]
DEBUG:root:#8 x7 [1.07064402 2.10919196] +[0.15576766 0.07565065] ~ [1.22641168 2.1848426 ] = [1.22788398 2.18440359]
This works fine with a reduced number of steps; with only four steps:
DEBUG:root:#0 x5 [0. 1.5] +[0.28333333 0.24157542] ~ [0.28333333 1.74157542] = [0.34477692 1.72196632]
DEBUG:root:#1 x5 [0.34477692 1.72196632] +[0.39185914 0.23689925] ~ [0.73663606 1.95886557] = [0.76604622 1.95631082]
DEBUG:root:#2 x8 [0.76604622 1.95631082] +[0.44524268 0.23153319] ~ [1.2112889 2.18784401] = [1.22788398 2.18440359]
Grid following
Keep the main idea (and its gradients); build a row-wise output:
import typing
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
from numpy._typing import ArrayLike
from scipy.optimize import root, check_grad
class Trivariate(typing.Protocol):
    """Structural type for callables ``f(x, ab) -> ndarray``.

    ``ab`` stacks the two parameters ``a`` and ``b`` along its first axis.
    """

    def __call__(self, x: np.ndarray, ab: np.ndarray) -> np.ndarray:
        ...
def foo(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    """Residual ``(1 + a) sin(a + b - x) - x`` with ``(a, b)`` unpacked from ``ab``."""
    a, b = ab
    amplitude = 1 + a
    return amplitude*np.sin(a + b - x) - x
def dfoo_dx(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    """Partial derivative of ``foo`` with respect to ``x``."""
    a, b = ab
    return -(1 + a)*np.cos(a + b - x) - 1
def dfoo_dab(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    """Partial derivatives of ``foo`` with respect to ``a`` and ``b``, stacked on axis 0."""
    a, b = ab
    phase = a + b - x
    shared = (1 + a)*np.cos(phase)
    d_da = shared + np.sin(phase)
    d_db = shared
    return np.stack((d_da, d_db))
def next_roots(
    baked_root, dfdx: Trivariate, dfdab: Trivariate,
    dab: np.ndarray,
    ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
) -> np.ndarray:
    """Continue the roots ``x0`` at parameters ``ab0`` to the parameters ``ab1``.

    ``dab`` is the ``(a, b)`` increment from ``ab0`` to ``ab1``. The new roots
    are predicted to first order by implicit differentiation and then
    corrected by ``baked_root(args=ab1, x0=prediction)``.

    Raises ValueError if the solver reports failure.
    """
    slope_x = dfdx(x0, ab0)
    # df/d(a, b) divided by df/dx, i.e. minus dx/d(a, b); the sign is applied
    # when forming the prediction.
    sensitivity = dfdab(x0, ab0)/slope_x
    prediction = x0 + (-dab) @ sensitivity
    solution = baked_root(args=ab1, x0=prediction)
    if solution.success:
        return solution.x
    raise ValueError('Root finding failed: ' + solution.message)
def roots_2d(
    fun: Trivariate, dfdx: Trivariate, dfdab: Trivariate,
    a0: float, a1: float,
    b0: float, b1: float,
    centre_estimate: float,
    resolution: int = 201,
    method: str = 'hybr',
    tol: float = 1e-2, reltol: float = 1e-2,
    dtype: np.dtype = np.float32,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Follow the root of ``fun`` over a ``resolution`` x ``resolution`` (a, b) grid.

    The root is solved directly only at the grid centre, seeded with
    ``centre_estimate``. It is then continued with ``next_roots``: first
    along the centre row, to the right and to the left one point at a time;
    then whole rows at once, downwards and upwards from the centre row.

    ``root`` uses ``tol`` only as a default for ``options['xtol']`` with
    methods 'hybr'/'lm'. Since ``xtol`` is always passed (as ``reltol``),
    ``tol`` has no effect for those methods.

    Returns
    -------
    (aa, bb, x)
        ``aa`` and ``bb`` are the ``np.meshgrid`` of the a and b samples
        (rows index b, columns index a); ``x`` holds the roots.

    Raises
    ------
    ValueError
        If ``resolution < 2`` (the step sizes need two samples) or if any
        root solve fails, including the one at the centre.
    """
    if resolution < 2:
        raise ValueError(f'resolution must be at least 2, got {resolution}')

    def dfdx_sparse(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
        # Assumes each grid point's residual depends only on its own x, so
        # the Jacobian is diagonal (stored densely here).
        return np.diag(dfdx(x, ab))

    baked_root = partial(
        root, fun=fun, jac=dfdx_sparse, method=method, tol=tol,
        options={'xtol': reltol, 'col_deriv': True},
    )
    baked_next = partial(
        next_roots, baked_root=baked_root, dfdx=dfdx, dfdab=dfdab,
    )

    aser = np.linspace(start=a0, stop=a1, num=resolution, dtype=dtype)
    bser = np.linspace(start=b0, stop=b1, num=resolution, dtype=dtype)
    da = aser[1] - aser[0]
    db = bser[1] - bser[0]
    aa, bb = np.meshgrid(aser, bser)
    aabb = np.stack((aa, bb))  # (2, resolution, resolution): (ab, b index, a index)
    out = np.empty_like(aa)

    # Centre point, the only one for which we rely on an estimate from the
    # caller. Check it like every other solve: a silent failure here would
    # seed the whole grid with garbage.
    imid = resolution//2
    centre = baked_root(args=aabb[:, imid, imid], x0=centre_estimate)
    if not centre.success:
        raise ValueError('Root finding failed: ' + centre.message)
    out[imid, imid] = centre.x.squeeze()

    # Centre to right, scalars
    dar = np.array((da, 0), dtype=da.dtype)
    for j in range(imid + 1, resolution):
        out[imid, j] = baked_next(
            dab=dar, ab0=aabb[:, imid, j-1], ab1=aabb[:, imid, j], x0=out[imid, j-1],
        ).squeeze()

    # Centre to left, scalars
    dal = -dar
    for j in range(imid - 1, -1, -1):
        out[imid, j] = baked_next(
            dab=dal, ab0=aabb[:, imid, j+1], ab1=aabb[:, imid, j], x0=out[imid, j+1],
        ).squeeze()

    # Down rows
    dbd = np.array((0, db), dtype=db.dtype)
    for i in range(imid + 1, resolution):
        out[i] = baked_next(
            dab=dbd, ab0=aabb[:, i-1], ab1=aabb[:, i], x0=out[i-1],
        )

    # Up rows
    dbu = -dbd
    for i in range(imid - 1, -1, -1):
        out[i] = baked_next(
            dab=dbu, ab0=aabb[:, i+1], ab1=aabb[:, i], x0=out[i+1],
        )

    return aa, bb, out
def plot(aa: np.ndarray, bb: np.ndarray, x: np.ndarray) -> plt.Figure:
    """Draw the roots ``x`` over the (a, b) grid as a colour mesh and return the figure."""
    fig, ax = plt.subplots()
    ax.grid()
    ax.set(xlabel='a', ylabel='b')
    colour_mesh = ax.pcolormesh(aa, bb, x, vmin=-3, vmax=3)
    fig.colorbar(colour_mesh, label='root')
    return fig
def main() -> None:
    """Check the analytic derivatives, then compute and plot the followed root."""
    x0 = np.array((0.1, 1.5))
    # Row 0 holds a, row 1 holds b; the two sample points are
    # (a, b) = (0.3, 0.1) and (2, 0.2).
    ab0 = np.array([(0.3, 2), (0.1, 0.2)])
    err = check_grad(
        partial(foo, ab=ab0),
        lambda x: np.diag(dfoo_dx(x, ab0)),
        x0,
    )
    assert err < 1e-7

    # check_grad takes a 1-D parameter vector, so check d/da and d/db one row
    # of ab0 at a time. Each point depends only on its own parameters, so the
    # Jacobian with respect to one row is diagonal.
    a_row, b_row = ab0
    err = check_grad(
        lambda a: foo(x0, np.stack((a, b_row))),
        lambda a: np.diag(dfoo_dab(x0, np.stack((a, b_row)))[0]),
        a_row,
    )
    assert err < 1e-7
    err = check_grad(
        lambda b: foo(x0, np.stack((a_row, b))),
        lambda b: np.diag(dfoo_dab(x0, np.stack((a_row, b)))[1]),
        b_row,
    )
    assert err < 1e-7

    aa, bb, x = roots_2d(
        fun=foo, dfdx=dfoo_dx, dfdab=dfoo_dab,
        a0=-1, a1=1,
        b0=-3, b1=3,
        centre_estimate=0,
    )
    plot(aa, bb, x)
    plt.show()


if __name__ == '__main__':
    main()
Executes in a second or two:
There is no perfect way to handle disappearing roots. Either you write a flood-fill algorithm, which is complicated, or you use a simple heuristic, e.g. marking a point as lost (`nan`) when the polished root lands too far from the linear prediction:
def next_roots(
    baked_root, dfdx: Trivariate, dfdab: Trivariate,
    dab: np.ndarray,
    ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
    error_bound: float = 1e-2,
) -> np.ndarray:
    """Continue the roots ``x0`` at ``ab0`` to ``ab1``, marking lost roots as NaN.

    ``dab`` is the ``(a, b)`` increment from ``ab0`` to ``ab1``. Points whose
    ``x0`` is not finite were lost earlier; they are skipped and stay NaN.
    For the rest, the root is predicted to first order and corrected by
    ``baked_root``. A corrected root that lies more than ``error_bound`` away
    from its prediction is set to NaN. This is a heuristic for a root that
    vanished or jumped to another branch.

    Returns an array shaped like ``x0``.

    Raises ValueError if the solver reports failure.
    """
    x0 = np.asarray(x0)
    ab0 = np.asarray(ab0)
    ab1 = np.asarray(ab1)
    input_mask = np.isfinite(x0)
    xnew = np.full_like(x0, fill_value=np.nan)
    if not input_mask.any():
        # Every root is already lost; calling the solver with an empty x0
        # would fail, so just propagate the NaNs.
        return xnew
    x0_masked = x0[input_mask]
    dfdxi = dfdx(x0_masked, ab0[:, input_mask])
    dxdab = dfdab(x0_masked, ab0[:, input_mask])/dfdxi
    # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
    step = (-dab) @ dxdab
    xest = x0_masked + step
    result = baked_root(args=ab1[:, input_mask], x0=xest)
    if not result.success:
        raise ValueError('Root finding failed: ' + result.message)
    xsol = result.x
    est_error = np.abs(xsol - xest)
    xsol[est_error > error_bound] = np.nan
    xnew[input_mask] = xsol
    return xnew
If I understood your question correctly, the ready-made solution you are looking for is `fsolve` from `scipy.optimize`. A relatively easy-to-implement alternative is the bisection method compiled with numba, which can give better overall performance because it is just-in-time compiled.
import time
import numpy as np
from scipy.optimize import root_scalar, fsolve
import numba as nb
import functools
from matplotlib import pyplot as plt
from types import FunctionType
# --- Function definitions ---
def foo(x, a, b):
    """Residual whose root is sought: ``(1 + a) * sin(a + b - x) - x``."""
    scale = 1 + a
    return scale * np.sin(a + b - x) - x
def _dfoo(x, a, b):
    """Return the derivative of ``foo`` with respect to ``x``."""
    cos_term = np.cos(a + b - x)
    return -(1 + a) * cos_term - 1
# Original (cached) solver for reference:
@functools.cache
def _fooroot(an, astep, bn, bstep):
    """Return the followed root of ``foo`` at ``(a, b) = (an * astep, bn * bstep)``.

    The root is continued from the known root ``0`` at ``(0, 0)`` through
    grid points. At ``an == 0`` the predecessor of a point is
    ``(0, bn - 1)``; otherwise it is ``(an - 1, bn)``. In other words, the
    path runs along ``b`` at ``a = 0`` and then along ``a``. Each point is
    solved by ``_solve_fooroot``, seeded with its predecessor's root.
    ``functools.cache`` memoises every visited point.

    ``an`` and ``bn`` must be >= 0. ``fooroot`` ensures this by giving each
    step the sign of its parameter. For a negative index the predecessor
    chain never reaches 0, so the recursion would not terminate.

    Returns ``0`` at the origin. Otherwise it returns the result of
    ``_solve_fooroot``, which is NaN when no root could be bracketed.
    """
    if an == 0:
        if bn == 0:
            # Known starting root at (a, b) == (0, 0).
            return 0
        # On the a == 0 axis: step back along b.
        guessan, guessbn = an, bn - 1
    else:
        # Step back along a at fixed b.
        guessan, guessbn = an - 1, bn
    # Avoid reaching maximum recursion depth.
    # Solve every 100th point on the path first, in increasing order, so the
    # recursive call below reaches an already-cached point within about 100
    # levels of recursion.
    for thisbn in range(0, guessbn, 100):
        _fooroot(0, astep, thisbn, bstep)
    for thisan in range(0, guessan, 100):
        _fooroot(thisan, astep, guessbn, bstep)
    guess = _fooroot(guessan, astep, guessbn, bstep)
    return _solve_fooroot(guess, an * astep, bn * bstep)
def _solve_fooroot(guess, a, b):
    """Polish ``guess`` into a nearby root of ``foo(., a, b)``.

    Starting at ``guess``, walk in the direction of the Newton step
    ``-foo/dfoo`` in increments of a tenth of that step. The step is first
    enlarged to at least ``1e-2 * maxlim`` in magnitude, and the walk covers
    at most ten such steps. It stops at the first sign change of ``foo``; the
    root inside that bracket is then found with ``root_scalar``.

    Returns ``guess`` itself if it is already an exact root. Returns
    ``np.nan`` in any of these cases:

    * ``guess`` is NaN, e.g. because ``_fooroot`` lost the root at an
      earlier grid point;
    * the Newton step is undefined (zero or non-finite derivative, or a
      non-finite residual);
    * no sign change is found within the search range.

    Each of these means the followed root appears to have been lost.
    """
    if np.isnan(guess):
        # Without this guard a NaN guess makes np.arange below raise, so a
        # lost root would crash the caller instead of propagating as NaN.
        return np.nan
    # Scale for the minimum search step; slightly larger than 1 + a to allow
    # for numerical imprecision.
    maxlim = 1.1 * (1 + a)
    y0 = foo(guess, a, b)
    if y0 == 0:
        return guess
    dy0 = _dfoo(guess, a, b)
    if dy0 == 0 or not np.isfinite(dy0):
        # No usable Newton direction (e.g. near a point where the followed
        # root vanishes). Dividing would give inf/nan, and np.arange below
        # would then raise instead of reporting a lost root.
        return np.nan
    estimate = -y0 / dy0
    if not np.isfinite(estimate):
        return np.nan
    # Too small estimates are no good: enforce a minimum magnitude while
    # keeping the direction. np.copysign also yields a non-zero step if the
    # estimate underflowed to 0, where np.sign(0) would give a zero step.
    if np.abs(estimate) < 1e-2 * maxlim:
        estimate = np.copysign(1e-2 * maxlim, estimate)
    for lim in np.arange(guess, guess + 10 * estimate, 1e-1 * estimate):
        if np.sign(foo(lim, a, b)) != np.sign(y0):
            bracket = sorted([guess, lim])
            break
    else:
        # No sign change: the root we were following is gone.
        return np.nan
    sol = root_scalar(foo, bracket=bracket, args=(a, b))
    return sol.root
@functools.partial(np.vectorize, otypes=[float])
@functools.cache
def fooroot(a, b):
    """Return the root of ``foo(., a, b)`` continued from ``x0 = 0`` at ``(0, 0)``.

    Vectorized over ``a`` and ``b``; results for repeated ``(a, b)`` pairs
    are memoised. ``_fooroot`` continues the root over a grid of spacing
    0.01; the result is then polished at ``(a, b)`` itself.

    ``otypes=[float]`` pins the output dtype. Without it, ``np.vectorize``
    infers the dtype from the first element. ``fooroot(0, 0)`` yields the
    integer ``0``, so an input that starts at ``(0, 0)`` would otherwise
    truncate every other root to an integer.

    The previous ``excluded=[2, 3, 4, 5]`` referred to argument positions
    this two-argument function does not have, so it has been dropped.
    """
    # Give each step the sign of its parameter so that the grid indices
    # passed to _fooroot are non-negative (it counts them down to 0).
    astep = (-1 if a < 0 else 1) * 1e-2
    bstep = (-1 if b < 0 else 1) * 1e-2
    guess = _fooroot(int(a / astep), astep, int(b / bstep), bstep)
    return _solve_fooroot(guess, a, b)
# --- SciPy fsolve implementation ---
def scipy_fsolve(fun, x0, tol, mxiter, *args):
    """Solve ``fun(x, *args) = 0`` with ``scipy.optimize.fsolve`` starting at ``x0``.

    ``tol`` is passed as ``xtol`` and ``mxiter`` as ``maxfev``. Returns the
    first component of the solution.

    Raises RuntimeError if fsolve does not report success.
    """
    solution, _info, status, message = fsolve(
        fun, x0, args=args, xtol=tol, maxfev=mxiter, full_output=True,
    )
    if status == 1:
        return solution[0]
    raise RuntimeError(f"fsolve did not converge: {message}")
# --- Numba-based solver ---
def compile_specialized_bisect(fun):
    """Returns a compiled bisection implementation for ``fun``.

    The returned numba function has the signature
    ``(a, b, tol, mxiter, *args)`` and looks for a root of ``fun(x, *args)``
    on ``[a, b]``. It returns:

    * ``a`` or ``b`` if ``|fun|`` is already below ``tol`` there;
    * ``np.nan`` if ``fun`` has the same sign at both ends, because then the
      interval does not bracket a root and bisection has no convergence
      guarantee;
    * otherwise, the midpoint at which ``|fun| <= tol``, or the last
      midpoint after ``mxiter`` halvings.
    """
    compiled_f = nb.njit()(fun)

    def python_bisect(a, b, tol, mxiter, *args):
        """Python implementation of bisection method."""
        its = 0
        fa = compiled_f(a, *args)
        fb = compiled_f(b, *args)
        if np.abs(fa) < tol:
            return a
        if np.abs(fb) < tol:
            return b
        if fa * fb > 0:
            # No sign change: previously the loop below would run anyway and
            # silently return a meaningless point.
            return np.nan
        c = (a + b) / 2.0
        fc = compiled_f(c, *args)
        while np.abs(fc) > tol and its < mxiter:
            its += 1
            if fa * fc < 0:
                # Sign change in [a, c]: keep the left half.
                b = c
            else:
                a = c
                fa = fc
            c = (a + b) / 2.0
            fc = compiled_f(c, *args)
        return c

    return nb.njit()(python_bisect)
@functools.lru_cache(maxsize=None)
def _bisect_for(fun):
    """Compile the bisection solver for ``fun`` once and reuse it afterwards."""
    return compile_specialized_bisect(fun)


def numba_bisect(fun, a_val, b_val, tol, mxiter, *args):
    """Find a root on ``[a_val, b_val]`` by compiled bisection.

    ``fun`` is either:

    * a plain Python function ``f(x, *args)``. A specialised solver is
      compiled for it with ``compile_specialized_bisect`` on first use and
      cached, so repeated calls do not pay the JIT compilation cost again.
      Previously every call recompiled.
    * anything else, such as an already compiled solver. It is called
      directly as ``fun(a_val, b_val, tol, mxiter, *args)``.
    """
    if isinstance(fun, FunctionType):
        return _bisect_for(fun)(a_val, b_val, tol, mxiter, *args)
    return fun(a_val, b_val, tol, mxiter, *args)
# Compile specialized functions to be used in numba_bisect below:
_jit_bisect_foo = compile_specialized_bisect(foo)
# The first call triggers numba's JIT compilation for these argument types;
# doing it here keeps compilation time out of the timings below.
dummy_wakeup_call_jit = _jit_bisect_foo(-np.pi, np.pi, 1e-6, 1000, 0, 0)

# --- Timing comparison ---
params = [(0, 0), (2, 0), (2, 1)]
tol = 1e-6
mxiter = 1000

# time.perf_counter() is the clock intended for timing short durations;
# time.time() can be too coarse for the microsecond timings printed here.
print("Original method (fooroot):")
for a, b in params:
    start = time.perf_counter()
    root_val = fooroot(a, b)
    elapsed = time.perf_counter() - start
    print(f"fooroot({a}, {b}) = {root_val:.6f} in {elapsed:.6f} seconds")

print("\nNumba bisection method:")
for a, b in params:
    start = time.perf_counter()
    root_val = numba_bisect(_jit_bisect_foo, -np.pi, np.pi, tol, mxiter, a, b)
    elapsed = time.perf_counter() - start
    print(f"numba_bisect(foo, {a}, {b}) = {root_val:.6f} in {elapsed:.6f} seconds")

print("\nSciPy fsolve method:")
for a, b in params:
    start = time.perf_counter()
    # The initial guess (here simply `a`) determines which of foo's roots
    # fsolve converges to; for these parameter sets it happens to lie near
    # the wanted root.
    root_val = scipy_fsolve(foo, a, tol, mxiter, a, b)
    elapsed = time.perf_counter() - start
    print(f"fsolve(foo, {a}, {b}) = {root_val:.6f} in {elapsed:.6f} seconds")

# --- Plotting ---
# Plot the function foo for each parameter set along with its zero.
x_vals = np.linspace(-np.pi, np.pi, 400)
plt.figure(figsize=(10, 6))
plt.grid(True)
plt.xlabel("x")
plt.ylabel("foo(x, a, b)")
colors = ['blue', 'green', 'red']
for idx, (a, b) in enumerate(params):
    y_vals = foo(x_vals, a, b)
    # Derive the label from params so the two cannot drift apart.
    plt.plot(x_vals, y_vals, color=colors[idx], label=f"(a, b) = ({a}, {b})")
    # Compute the zero using fooroot
    fooroot_zero = fooroot(a, b)
    plt.plot(fooroot_zero, foo(fooroot_zero, a, b), 'kx', markersize=8)
    # Compute the zero using numba
    bisect_zero = numba_bisect(_jit_bisect_foo, -np.pi, np.pi, tol, mxiter, a, b)
    plt.plot(bisect_zero, foo(bisect_zero, a, b), 'ko', markersize=8)
    plt.text(bisect_zero, foo(bisect_zero, a, b), f" {bisect_zero:.2f}",
             fontsize=9, color=colors[idx])
plt.legend()
plt.title("Function curves and computed zeros")
plt.show()
This is the output on a Colab notebook:
Original method (fooroot):
fooroot(0, 0) = 0.000000 in 0.000194 seconds
fooroot(2, 0) = 1.482951 in 0.024140 seconds
fooroot(2, 1) = 2.184404 in 0.033811 seconds
Numba bisection method:
numba_bisect(foo, 0, 0) = 0.000000 in 0.000024 seconds
numba_bisect(foo, 2, 0) = 1.482951 in 0.000005 seconds
numba_bisect(foo, 2, 1) = 2.184404 in 0.000004 seconds
SciPy fsolve method:
fsolve(foo, 0, 0) = 0.000000 in 0.000210 seconds
fsolve(foo, 2, 0) = 1.482951 in 0.000129 seconds
fsolve(foo, 2, 1) = 2.184404 in 0.000132 seconds
(Most likely this is not what you are asking, but I will still leave the answer because there are not many implementations of numba-based solvers for parametric functions on stackoverflow).
`root_scalar` iteratively, but with extremely different parameters: pass `fprime` and `fprime2`, relax the tolerances, and do not use the default `method`. – Reinderien Commented Mar 31 at 11:56