python - Find correct root of parametrized function given solution for one set of parameters - Stack Overflow

Let's say I have a function foo(x, a, b) and I want to find a specific one of its (potentially multiple) roots x0, i.e. a value x0 such that foo(x0, a, b) == 0. I know that for (a, b) == (0, 0) the root I want is x0 == 0 and that the function changes continuously with a and b, so I can "follow" the root from (0, 0) to the desired (a, b).

Here's an example function.

def foo(x, a, b):
    return (1 + a) * np.sin(a + b - x) - x

For (a, b) == (0, 0) I want the root at 0, for (2, 0) I want the one near 1.5 and for (2, 1) I want the one near 2.2.

Now, this problem seems like one that may be common enough to have a prepared, fast solver in scipy or another standard package (or tools to easily and efficiently construct one). However, I don't know what terms to search for to find it (or verify that it doesn't exist). Is there a ready-made tool for this? What is this kind of problem called?


Clarifications:

  • Depending on (a, b), the "correct" root may disappear (e.g. for (1, 3) in the example). When this happens, returning nan is the preferred behavior, though this is not super important.
  • By "fast" I mostly mean that many parameter sets (a, b) can be quickly solved, not just a single one. I will go on calculating the root for a lot of different parameters, e.g. for plotting and integrating over them.

Here's a quickly put together reference implementation that does pretty much what I want for the example above (and creates the plot). It's not very fast, of course, which somewhat limits me in my actual application.

import functools
import numpy as np
from scipy.optimize import root_scalar
from matplotlib import pyplot as plt

def foo(x, a, b):
    return (1 + a) * np.sin(a + b - x) - x

fig, ax = plt.subplots()
ax.grid()
ax.set_xlabel("x")
ax.set_ylabel("foo")
x = np.linspace(-np.pi, np.pi, 201)
ax.plot(x, foo(x, 0, 0), label="(a, b) = (0, 0)")
ax.plot(x, foo(x, 2, 0), label="(a, b) = (2, 0)")
ax.plot(x, foo(x, 2, 1), label="(a, b) = (2, 1)")
ax.legend()
plt.show()

# Semi-bodged solver for reference:

def _dfoo(x, a, b):
    return -(1 + a) * np.cos(a + b - x) - 1

def _solve_fooroot(guess, a, b):
    if np.isnan(guess):
        return np.nan
    # Determine limits for finding the root.
    # Allow for slightly larger limits to account for numerical imprecision.
    maxlim = 1.1 * (1 + a)
    y0 = foo(guess, a, b)
    if y0 == 0:
        return guess
    dy0 = _dfoo(guess, a, b)
    estimate = -y0 / dy0
    # Too small estimates are no good.
    if np.abs(estimate) < 1e-2 * maxlim:
        estimate = np.sign(estimate) * 1e-2 * maxlim
    for lim in np.arange(guess, guess + 10 * estimate, 1e-1 * estimate):
        if np.sign(foo(lim, a, b)) != np.sign(y0):
            bracket = sorted([guess, lim])
            break
    else:
        return np.nan
    sol = root_scalar(foo, (a, b), bracket=bracket)
    return sol.root

@functools.cache
def _fooroot(an, astep, bn, bstep):
    if an == 0:
        if bn == 0:
            return 0
        guessan, guessbn = an, bn - 1
    else:
        guessan, guessbn = an - 1, bn
    # Avoid reaching maximum recursion depth.
    for thisbn in range(0, guessbn, 100):
        _fooroot(0, astep, thisbn, bstep)
    for thisan in range(0, guessan, 100):
        _fooroot(thisan, astep, guessbn, bstep)
    guess = _fooroot(guessan, astep, guessbn, bstep)
    return _solve_fooroot(guess, an * astep, bn * bstep)

@np.vectorize
def fooroot(a, b):
    astep = (-1 if a < 0 else 1) * 1e-2
    bstep = (-1 if b < 0 else 1) * 1e-2
    guess = _fooroot(int(a / astep), astep, int(b / bstep), bstep)
    return _solve_fooroot(guess, a, b)

print(fooroot(0, 0))
print(fooroot(2, 0))
print(fooroot(2, 1))

fig, ax = plt.subplots()
ax.grid()
ax.set_xlabel("b")
ax.set_ylabel("fooroot(a, b)")
b = np.linspace(-3, 3, 201)
for a in [0, 0.2, 0.5]:
    ax.plot(b, fooroot(a, b), label=f"a = {a}")
ax.legend()
plt.show()

fig, ax = plt.subplots()
ax.grid()
ax.set_xlabel("a")
ax.set_ylabel("b")
a = np.linspace(-1, 1, 201)
b = np.linspace(-3.5, 3.5, 201)
aa, bb = np.meshgrid(a, b)
pcm = ax.pcolormesh(aa, bb, fooroot(aa, bb))
fig.colorbar(pcm, label="fooroot(a, b)")
plt.show()

Asked by schtandard on Mar 31 at 10:14; edited Apr 1 at 9:13.

Comments (5):
  • It doesn't at all surprise me that there's no generic function for this; the devil is in the details. Is the function differentiable - ideally double-differentiable? Writing a good solution for "any function" is effectively impossible. Can you show your actual function? – Reinderien Commented Mar 31 at 11:51
  • With not at all enough information about the true function, here is how I would approach it: get an initial solution for your correct root, then use root_scalar iteratively but with extremely different parameters. Pass fprime and fprime2, relax the tolerances, and do not use the default method. – Reinderien Commented Mar 31 at 11:56
  • Another reason you should have your derivatives: given a step size in a and b and the Jacobian in that parameter space, you should be able to estimate where the new root will be. – Reinderien Commented Mar 31 at 12:00
  • @Reinderien Yes, the function is double-differentiable. Currently, my function is pretty similar to the one in my question (with a bunch of extra parameters that are not relevant for finding the root), but it may change. Solutions that only work with (arbitrarily) differentiable functions are absolutely fine for me. (Indeed, I did use the derivative to estimate how the root shifts in my code.) – schtandard Commented Mar 31 at 12:35
  • @Reinderien I'm not sure I can fully follow your suggested approach. Looking at extremely different parameters right away risks losing track of the correct root, no? Otherwise it seems quite similar to what I did, albeit with a more sophisticated idea of how to estimate how the root shifts when stepping the parameters, maybe? – schtandard Commented Mar 31 at 12:37

2 Answers

Get rid of your @cache and @vectorize; neither is likely to help you for the following and they're just noise. (If they're needed for outer code, you haven't shown that outer code, so the point stands.)

Do keep using Scipy's root-finding iteratively, but beyond that your procedure should look pretty different. I propose:

Get your initial estimate x0, a0, b0. Then in a loop:

  1. Infer by analytic integration the paraboloid intersecting the current point whose first and second derivatives with respect to a and b match those of f.
  2. Increment a and b at their fixed step size. If the function is smooth and parabolic step estimation works well, this step may be fairly large; but if you're writing a generic routine that can take any function, the step size must be a parameter.
  3. Call root_scalar with your new estimate, new a and b, passing analytic fprime and fprime2, probably using Halley's Method, and having relaxed tolerances that are parametric and appropriate to the function.
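The corrector in step 3 can be sketched for a single scalar root as follows. This is only an illustration under stated assumptions: `d2foo_dx2` is a helper name introduced here (it is not in the question), the starting estimate 1.5 is the rough value quoted in the question for (a, b) = (2, 0), and `xtol=1e-6` stands in for a "relaxed" tolerance.

```python
import math

from scipy.optimize import root_scalar


def foo(x, a, b):
    return (1 + a) * math.sin(a + b - x) - x


def dfoo_dx(x, a, b):
    # First derivative of foo with respect to x.
    return -(1 + a) * math.cos(a + b - x) - 1


def d2foo_dx2(x, a, b):
    # Second derivative of foo with respect to x (hypothetical helper name).
    return -(1 + a) * math.sin(a + b - x)


# Polish a rough estimate of the root at (a, b) = (2, 0) with Halley's method.
sol = root_scalar(
    foo, args=(2, 0), method='halley',
    x0=1.5, fprime=dfoo_dx, fprime2=d2foo_dx2,
    xtol=1e-6,
)
print(sol.converged, sol.root, sol.iterations)
```

Because Halley's method is started from a single point rather than a bracket, it only stays on the intended branch if the estimate passed as `x0` is already close to the root being followed.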

Of course that's the ideal case, but in practice Scipy limitations mean that the vectorised methods cannot easily use second-order gradients. That in turn means that Halley is unavailable, but even linear Jacobian steps work well.
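The linear step comes from differentiating the root condition implicitly. A sketch in the notation of the example, where the subscripts f_x, f_a, f_b are shorthand for partial derivatives introduced here for illustration:

```latex
\begin{aligned}
f\bigl(x(a,b),\,a,\,b\bigr) = 0
&\;\Longrightarrow\;
f_x\,\mathrm{d}x + f_a\,\mathrm{d}a + f_b\,\mathrm{d}b = 0
\;\Longrightarrow\;
\Delta x \approx -\frac{f_a\,\Delta a + f_b\,\Delta b}{f_x},
\\[4pt]
f_x &= -(1+a)\cos(a+b-x) - 1, \\
f_a &= (1+a)\cos(a+b-x) + \sin(a+b-x), \\
f_b &= (1+a)\cos(a+b-x).
\end{aligned}
```

The estimate degrades as f_x approaches zero, which is also where the followed root can merge with a neighbouring root and disappear.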

The following demonstration shows a fully-vectorised path traversal, with two independent start points and stop points in a, b space. This can be arbitrarily extended to as many consecutive paths as you want.

import logging
import typing

import numpy as np
from numpy.typing import ArrayLike  # public module; numpy._typing is private
from scipy.optimize import root, check_grad


class Trivariate(typing.Protocol):
    def __call__(
        self,
        x: np.ndarray, a: np.ndarray, b: np.ndarray,
    ) -> np.ndarray: ...


def foo(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (1 + a)*np.sin(a + b - x) - x


def dfoo_dx(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (-1 - a)*np.cos(a + b - x) - 1


def dfoo_da(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (1 + a)*np.cos(a + b - x) + np.sin(a + b - x)


def dfoo_db(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (1 + a)*np.cos(a + b - x)


def follow_root(
    fun: Trivariate, dfdx: Trivariate, dfda: Trivariate, dfdb: Trivariate,
    a0: ArrayLike, a1: ArrayLike,
    b0: ArrayLike, b1: ArrayLike,
    x0: ArrayLike,
    steps: int = 10,
    method: str = 'hybr',
    follow_tol: float = 1e-2, follow_reltol: float = 1e-2,
    polish_tol: float = 1e-12, polish_reltol: float = 1e-12,
) -> np.ndarray:
    def dfdx_sparse(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        return np.diag(dfdx(x, a, b))

    x_est = np.asarray(x0)
    ab0 = np.array((a0, b0))
    ab1 = np.array((a1, b1))

    # Shape (steps, 2, *shape of a0): (step index, a or b, path index...)
    ab = np.linspace(start=ab0, stop=ab1, num=steps)
    da = ab[1, 0] - ab[0, 0]
    db = ab[1, 1] - ab[0, 1]

    for i, ((ai, bi), (ai1, bi1)) in enumerate(zip(ab[:-1], ab[1:])):
        dfdxi = dfdx(x_est, ai, bi)
        dxda = dfda(x_est, ai, bi)/dfdxi
        dxdb = dfdb(x_est, ai, bi)/dfdxi
        # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
        step = -dxda*da - dxdb*db
        last = i == len(ab) - 2

        result = root(
            fun=fun, args=(ai1, bi1), jac=dfdx_sparse, x0=x_est + step, method=method,
            tol=polish_tol if last else follow_tol,
            options={
                'xtol': polish_reltol if last else follow_reltol,
                'col_deriv': True,
            },
        )
        if not result.success:
            raise ValueError('Root finding failed: ' + result.message)

        logging.debug('#%d x%d %s +%s ~ %s = %s', i, result.nfev, x_est, step, x_est + step, result.x)
        x_est = result.x

    return x_est


def main() -> None:
    # Don't do this in production!
    logging.getLogger().setLevel(logging.DEBUG)

    err = check_grad(
        lambda x: foo(x, np.array((0, 2)), np.array((0, 0))),
        lambda x: np.diag(dfoo_dx(x, np.array((0, 2)), np.array((0, 0)))),
        (0.1, 1.5),  # x0
    )
    assert err < 1e-7

    err = check_grad(
        lambda a: foo(np.array((0.1, 1.5)), a, np.array((0, 0))),
        lambda a: np.diag(dfoo_da(np.array((0.1, 1.5)), a, np.array((0, 0)))),
        (0, 2),  # a0
    )
    assert err < 1e-7

    err = check_grad(
        lambda b: foo(np.array((0.1, 1.5)), np.array((0, 2)), b),
        lambda b: np.diag(dfoo_db(np.array((0.1, 1.5)), np.array((0, 2)), b)),
        (0, 0),  # b0
    )
    assert err < 1e-7

    follow_root(
        fun=foo, dfdx=dfoo_dx, dfda=dfoo_da, dfdb=dfoo_db,
        a0=(0, 2), a1=(1.7, 2),
        b0=(0, 0), b1=(0  , 1),
        x0=(0, 1.5),
    )


if __name__ == '__main__':
    main()

Output:

DEBUG:root:#0 x5 [0.  1.5] +[0.09444444 0.08052514] ~ [0.09444444 1.58052514] = [0.1025362  1.56306428]
DEBUG:root:#1 x4 [0.1025362  1.56306428] +[0.10987707 0.07990566] ~ [0.21241326 1.64296994] = [0.21851137 1.64274921]
DEBUG:root:#2 x4 [0.21851137 1.64274921] +[0.12155443 0.07945781] ~ [0.3400658  1.72220702] = [0.34478035 1.72196514]
DEBUG:root:#3 x4 [0.34478035 1.72196514] +[0.13061949 0.0789664 ] ~ [0.47539984 1.80093153] = [0.47912896 1.80066588]
DEBUG:root:#4 x4 [0.47912896 1.80066588] +[0.13781338 0.07842657] ~ [0.61694234 1.87909245] = [0.61994978 1.87880026]
DEBUG:root:#5 x4 [0.61994978 1.87880026] +[0.14363144 0.07783262] ~ [0.76358122 1.95663288] = [0.76604746 1.95631084]
DEBUG:root:#6 x4 [0.76604746 1.95631084] +[0.14841417 0.07717773] ~ [0.91446163 2.03348857] = [0.9165135  2.03313271]
DEBUG:root:#7 x4 [0.9165135  2.03313271] +[0.15240176 0.07645371] ~ [1.06891526 2.10958643] = [1.07064402 2.10919196]
DEBUG:root:#8 x7 [1.07064402 2.10919196] +[0.15576766 0.07565065] ~ [1.22641168 2.1848426 ] = [1.22788398 2.18440359]

This works fine with a reduced number of steps; with only four steps:

DEBUG:root:#0 x5 [0.  1.5] +[0.28333333 0.24157542] ~ [0.28333333 1.74157542] = [0.34477692 1.72196632]
DEBUG:root:#1 x5 [0.34477692 1.72196632] +[0.39185914 0.23689925] ~ [0.73663606 1.95886557] = [0.76604622 1.95631082]
DEBUG:root:#2 x8 [0.76604622 1.95631082] +[0.44524268 0.23153319] ~ [1.2112889  2.18784401] = [1.22788398 2.18440359]

Grid following

Keep the main idea (and its gradients); build a row-wise output:

import typing
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike  # public module; numpy._typing is private
from scipy.optimize import root, check_grad


class Trivariate(typing.Protocol):
    def __call__(
        self, x: np.ndarray, ab: np.ndarray,
    ) -> np.ndarray: ...


def foo(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    a, b = ab
    return (1 + a)*np.sin(a + b - x) - x


def dfoo_dx(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    a, b = ab
    return (-1 - a)*np.cos(a + b - x) - 1


def dfoo_dab(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    a, b = ab
    abx = a + b - x
    a1cos = (1 + a)*np.cos(abx)
    return np.stack((a1cos + np.sin(abx), a1cos))


def next_roots(
    baked_root, dfdx: Trivariate, dfdab: Trivariate,
    dab: np.ndarray,
    ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
) -> np.ndarray:
    dfdxi = dfdx(x0, ab0)
    dxdab = dfdab(x0, ab0)/dfdxi
    # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
    step = (-dab) @ dxdab

    result = baked_root(args=ab1, x0=x0 + step)
    if not result.success:
        raise ValueError('Root finding failed: ' + result.message)

    return result.x


def roots_2d(
    fun: Trivariate, dfdx: Trivariate, dfdab: Trivariate,
    a0: float, a1: float,
    b0: float, b1: float,
    centre_estimate: float,
    resolution: int = 201,
    method: str = 'hybr',
    tol: float = 1e-2, reltol: float = 1e-2,
    dtype: np.dtype = np.float32,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    def dfdx_sparse(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
        return np.diag(dfdx(x, ab))

    baked_root = partial(
        root, fun=fun, jac=dfdx_sparse, method=method, tol=tol,
        options={'xtol': reltol, 'col_deriv': True},
    )
    baked_next = partial(
        next_roots, baked_root=baked_root, dfdx=dfdx, dfdab=dfdab,
    )

    aser = np.linspace(start=a0, stop=a1, num=resolution, dtype=dtype)
    bser = np.linspace(start=b0, stop=b1, num=resolution, dtype=dtype)
    da = aser[1] - aser[0]
    db = bser[1] - bser[0]
    aa, bb = np.meshgrid(aser, bser)
    aabb = np.stack((aa, bb))  # (2, 201, 201): (ab, b index, a index)
    out = np.empty_like(aa)

    # Centre point, the only one for which we rely on an estimate from the caller
    imid = resolution//2
    out[imid, imid] = baked_root(args=aabb[:, imid, imid], x0=centre_estimate).x.squeeze()

    # Centre to right, scalars
    dar = np.array((da, 0), dtype=da.dtype)
    for j in range(imid + 1, resolution):
        out[imid, j] = baked_next(
            dab=dar, ab0=aabb[:, imid, j-1], ab1=aabb[:, imid, j], x0=out[imid, j-1],
        ).squeeze()

    # Centre to left, scalars
    dal = -dar
    for j in range(imid - 1, -1, -1):
        out[imid, j] = baked_next(
            dab=dal, ab0=aabb[:, imid, j+1], ab1=aabb[:, imid, j], x0=out[imid, j+1],
        ).squeeze()

    # Down rows
    dbd = np.array((0, db), dtype=db.dtype)
    for i in range(imid + 1, resolution):
        out[i] = baked_next(
            dab=dbd, ab0=aabb[:, i-1], ab1=aabb[:, i], x0=out[i-1],
        )

    # Up rows
    dbu = -dbd
    for i in range(imid - 1, -1, -1):
        out[i] = baked_next(
            dab=dbu, ab0=aabb[:, i+1], ab1=aabb[:, i], x0=out[i+1],
        )

    return aa, bb, out


def plot(aa: np.ndarray, bb: np.ndarray, x: np.ndarray) -> plt.Figure:
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('a')
    ax.set_ylabel('b')
    mesh = ax.pcolormesh(aa, bb, x, vmin=-3, vmax=3)
    fig.colorbar(mesh, label='root')
    return fig


def main() -> None:
    x0 = np.array((0.1, 1.5))
    ab0 = np.array([(0.3, 2), (0.1, 0.2)])
    err = check_grad(
        partial(foo, ab=ab0),
        lambda x: np.diag(dfoo_dx(x, ab0)),
        x0,
    )
    assert err < 1e-7

    # err = check_grad(
    #     partial(foo, x0),
    #     lambda ab: dfoo_dab(x0, ab),
    #     ab0,
    # )
    # assert err < 1e-7

    aa, bb, x = roots_2d(
        fun=foo, dfdx=dfoo_dx, dfdab=dfoo_dab,
        a0=-1, a1=1,
        b0=-3, b1=3,
        centre_estimate=0,
    )
    plot(aa, bb, x)
    plt.show()


if __name__ == '__main__':
    main()

This executes in a second or two.

[Figure: pcolormesh of the followed root over the (a, b) grid]

To take care of disappearing roots, there are no perfect solutions. Either you need to write a flood fill algorithm, which is complicated; or you can just do a simple heuristic like

def next_roots(
    baked_root, dfdx: Trivariate, dfdab: Trivariate,
    dab: np.ndarray,
    ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
    error_bound: float = 1e-2,
) -> np.ndarray:
    input_mask = np.isfinite(x0)
    x0_masked = x0[input_mask]
    dfdxi = dfdx(x0_masked, ab0[:, input_mask])
    dxdab = dfdab(x0_masked, ab0[:, input_mask])/dfdxi
    # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
    step = (-dab) @ dxdab
    xest = x0_masked + step

    result = baked_root(args=ab1[:, input_mask], x0=xest)
    if not result.success:
        raise ValueError('Root finding failed: ' + result.message)

    xsol = result.x
    est_error = np.abs(xsol - xest)
    xsol[est_error > error_bound] = np.nan

    xnew = np.full_like(x0, fill_value=np.nan)
    xnew[input_mask] = xsol
    return xnew
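For a single parameter pair, the same jump heuristic can be written without SciPy. The following is a minimal sketch, not the code above: the function name `follow_root_or_nan`, the straight-line path from (0, 0), the hand-written Newton corrector, and the default values of `steps` and `max_jump` are all assumptions chosen for the example function.

```python
import math


def foo(x, a, b):
    return (1 + a) * math.sin(a + b - x) - x


def dfoo_dx(x, a, b):
    return -(1 + a) * math.cos(a + b - x) - 1


def dfoo_da(x, a, b):
    return (1 + a) * math.cos(a + b - x) + math.sin(a + b - x)


def dfoo_db(x, a, b):
    return (1 + a) * math.cos(a + b - x)


def follow_root_or_nan(a1, b1, steps=200, max_jump=0.5, tol=1e-12, max_iter=50):
    """Follow the root from (a, b, x) = (0, 0, 0) along a straight line to (a1, b1).

    Returns nan if the derivative vanishes, Newton fails to converge, or the
    root moves by more than max_jump in one parameter step.
    """
    x = 0.0
    for i in range(steps):
        t0, t1 = i / steps, (i + 1) / steps
        a, b = a1 * t0, b1 * t0
        da, db = a1 * (t1 - t0), b1 * (t1 - t0)
        fx = dfoo_dx(x, a, b)
        if abs(fx) < 1e-12:
            return math.nan
        # Predictor: linear estimate of how the root shifts with (a, b).
        x_new = x - (dfoo_da(x, a, b) * da + dfoo_db(x, a, b) * db) / fx
        # Corrector: Newton iterations at the new parameters.
        a, b = a1 * t1, b1 * t1
        converged = False
        for _ in range(max_iter):
            if abs(x_new - x) > max_jump:
                break  # wandered away from the branch being followed
            fx = dfoo_dx(x_new, a, b)
            if abs(fx) < 1e-12:
                break
            dx = foo(x_new, a, b) / fx
            x_new -= dx
            if abs(dx) < tol:
                converged = True
                break
        if not converged or abs(x_new - x) > max_jump:
            return math.nan
        x = x_new
    return x


print(follow_root_or_nan(2, 0))
print(follow_root_or_nan(2, 1))
print(follow_root_or_nan(1, 3))
```

Here the jump is measured against the previous accepted root rather than against the prediction, because the prediction itself becomes unreliable when the x-derivative is close to zero. The bound `max_jump` has to be tuned to the function and the step size, just like `error_bound` above.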

If I understood your question correctly, the ready-made solution you are looking for is fsolve from scipy.optimize. A relatively easy alternative to implement yourself is the bisection method with numba, which can give better overall performance because it is just-in-time compiled.

import time
import numpy as np
from scipy.optimize import root_scalar, fsolve
import numba as nb
import functools
from matplotlib import pyplot as plt
from types import FunctionType

# --- Function definitions ---

def foo(x, a, b):
    return (1 + a) * np.sin(a + b - x) - x

def _dfoo(x, a, b):
    return -(1 + a) * np.cos(a + b - x) - 1

# Original (cached) solver for reference:
@functools.cache
def _fooroot(an, astep, bn, bstep):
    if an == 0:
        if bn == 0:
            return 0
        guessan, guessbn = an, bn - 1
    else:
        guessan, guessbn = an - 1, bn
    # Avoid reaching maximum recursion depth.
    for thisbn in range(0, guessbn, 100):
        _fooroot(0, astep, thisbn, bstep)
    for thisan in range(0, guessan, 100):
        _fooroot(thisan, astep, guessbn, bstep)
    guess = _fooroot(guessan, astep, guessbn, bstep)
    return _solve_fooroot(guess, an * astep, bn * bstep)

def _solve_fooroot(guess, a, b):
    maxlim = 1.1 * (1 + a)
    y0 = foo(guess, a, b)
    if y0 == 0:
        return guess
    dy0 = _dfoo(guess, a, b)
    estimate = -y0 / dy0
    if np.abs(estimate) < 1e-2 * maxlim:
        estimate = np.sign(estimate) * 1e-2 * maxlim
    for lim in np.arange(guess, guess + 10 * estimate, 1e-1 * estimate):
        if np.sign(foo(lim, a, b)) != np.sign(y0):
            bracket = sorted([guess, lim])
            break
    else:
        return np.nan
    sol = root_scalar(foo, bracket=bracket, args=(a, b))
    return sol.root

@np.vectorize(excluded=[2, 3, 4, 5])
@functools.cache
def fooroot(a, b):
    astep = (-1 if a < 0 else 1) * 1e-2
    bstep = (-1 if b < 0 else 1) * 1e-2
    guess = _fooroot(int(a / astep), astep, int(b / bstep), bstep)
    return _solve_fooroot(guess, a, b)

# --- SciPy fsolve implementation ---
def scipy_fsolve(fun, x0, tol, mxiter, *args):
    sol, infodict, ier, mesg = fsolve(fun, x0, args=args, xtol=tol, maxfev=mxiter, full_output=True)
    if ier != 1:
        raise RuntimeError(f"fsolve did not converge: {mesg}")
    return sol[0]

# --- Numba-based solver ---
def compile_specialized_bisect(fun):
    """Returns a compiled bisection implementation for ``fun``."""
    compiled_f = nb.njit()(fun)

    def python_bisect(a, b, tol, mxiter, *args):
        """Python implementation of bisection method."""
        its = 0
        fa = compiled_f(a, *args)
        fb = compiled_f(b, *args)
        if np.abs(fa) < tol:
            return a
        if np.abs(fb) < tol:
            return b
        c = (a + b) / 2.0
        fc = compiled_f(c, *args)
        while np.abs(fc) > tol and its < mxiter:
            its += 1
            if fa * fc < 0:
                b = c
                fb = fc
            else:
                a = c
                fa = fc
            c = (a + b) / 2.0
            fc = compiled_f(c, *args)
        return c

    return nb.njit()(python_bisect)

def numba_bisect(fun, a_val, b_val, tol, mxiter, *args):
    if isinstance(fun, FunctionType):
        jit_bisect = compile_specialized_bisect(fun)
        return jit_bisect(a_val, b_val, tol, mxiter, *args)
    return fun(a_val, b_val, tol, mxiter, *args)

# Compile specialized functions to be used in numba_bisect below:
_jit_bisect_foo = compile_specialized_bisect(foo)
dummy_wakeup_call_jit = _jit_bisect_foo(-np.pi, np.pi, 1e-6, 1000, 0, 0)

# --- Timing comparison ---

params = [(0, 0), (2, 0), (2, 1)]
tol = 1e-6
mxiter = 1000

print("Original method (fooroot):")
for a, b in params:
    start = time.time()
    root_val = fooroot(a, b)
    elapsed = time.time() - start
    print(f"fooroot({a}, {b}) = {root_val:.6f} in {elapsed:.6f} seconds")

print("\nNumba bisection method:")
for a, b in params:
    start = time.time()
    # root_val = _jit_bisect(-np.pi, np.pi, tol, mxiter, a, b)
    root_val = numba_bisect(_jit_bisect_foo, -np.pi, np.pi, tol, mxiter, a, b)
    elapsed = time.time() - start
    print(f"numba_bisect(foo, {a}, {b}) = {root_val:.6f} in {elapsed:.6f} seconds")

print("\nSciPy fsolve method:")
for a, b in params:
    start = time.time()
    root_val = scipy_fsolve(foo, a, tol, mxiter, a, b)
    elapsed = time.time() - start
    print(f"fsolve(foo, {a}, {b}) = {root_val:.6f} in {elapsed:.6f} seconds")

# --- Plotting ---
# Plot the function foo for each parameter set along with its zero.
x_vals = np.linspace(-np.pi, np.pi, 400)
plt.figure(figsize=(10, 6))
plt.grid(True)
plt.xlabel("x")
plt.ylabel("foo(x, a, b)")

colors = ['blue', 'green', 'red']
labels = ["(a, b) = (0, 0)", "(a, b) = (2, 0)", "(a, b) = (2, 1)"]

for idx, (a, b) in enumerate(params):
    y_vals = foo(x_vals, a, b)
    plt.plot(x_vals, y_vals, color=colors[idx], label=labels[idx])
    # Compute the zero using fooroot
    fooroot_zero = fooroot(a, b)
    plt.plot(fooroot_zero, foo(fooroot_zero, a, b), 'kx', markersize=8)

    # Compute the zero using numba
    bisect_zero = numba_bisect(_jit_bisect_foo, -np.pi, np.pi, tol, mxiter, a, b)
    plt.plot(bisect_zero, foo(bisect_zero, a, b), 'ko', markersize=8)
    plt.text(bisect_zero, foo(bisect_zero, a, b), f"    {bisect_zero:.2f}",
             fontsize=9, color=colors[idx])

plt.legend()
plt.title("Function curves and computed zeros")
plt.show()

This is the output on a Colab notebook:

Original method (fooroot):
fooroot(0, 0) = 0.000000 in 0.000194 seconds
fooroot(2, 0) = 1.482951 in 0.024140 seconds
fooroot(2, 1) = 2.184404 in 0.033811 seconds

Numba bisection method:
numba_bisect(foo, 0, 0) = 0.000000 in 0.000024 seconds
numba_bisect(foo, 2, 0) = 1.482951 in 0.000005 seconds
numba_bisect(foo, 2, 1) = 2.184404 in 0.000004 seconds

SciPy fsolve method:
fsolve(foo, 0, 0) = 0.000000 in 0.000210 seconds
fsolve(foo, 2, 0) = 1.482951 in 0.000129 seconds
fsolve(foo, 2, 1) = 2.184404 in 0.000132 seconds

(Most likely this is not what you are asking, but I will still leave the answer because there are not many implementations of numba-based solvers for parametric functions on Stack Overflow.)
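One point worth checking before relying on a fixed bracket: bisection returns some root inside the bracket, not necessarily the one followed from (0, 0). A minimal sketch that counts sign changes of the example function on the bracket [-pi, pi] used above; the helper name `sign_change_brackets` and the grid size are assumptions made for this illustration.

```python
import math


def foo(x, a, b):
    return (1 + a) * math.sin(a + b - x) - x


def sign_change_brackets(a, b, lo=-math.pi, hi=math.pi, n=400):
    """Return the grid intervals in [lo, hi] on which foo changes sign."""
    xs = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    ys = [foo(x, a, b) for x in xs]
    return [(xs[i], xs[i + 1]) for i in range(n - 1) if ys[i] * ys[i + 1] < 0]


for a, b in [(0, 0), (2, 0), (2, 1)]:
    brackets = sign_change_brackets(a, b)
    print((a, b), len(brackets), brackets)
```

For (2, 0) and (2, 1) the scan finds three sign changes in [-pi, pi], so the bracket alone does not single out the desired root. Narrowing the bracket around an estimate of that root removes the ambiguity.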
