This is a follow-up question to: How to wrap NumPy functions in Numba-jitted code with persistent disk caching?
Background: In general, Numba's implementations of NumPy functions are pretty efficient, but in some cases, such as numpy.sort(), this is not the case. I would like to use NumPy's native sorting functionality (i.e., numpy.sort) within a Numba pipeline. The goal is to make stable, cached calls to the underlying C API sorting function that keep working across multiple Python sessions.
The C API for sorting (equivalent to ndarray.sort) is defined as:
PyObject *PyArray_Sort(PyArrayObject *self, int axis, NPY_SORTKIND kind)
.html#item-selection-and-manipulation
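Since PyArray_Sort is documented as equivalent to ndarray.sort, the Python-level behaviour is a handy reference for what the C call should do: the method sorts in place along `axis` and returns None, while np.sort sorts a copy (NumPy's implementation copies the input and then calls the method). This snippet uses only public NumPy API:

```python
import numpy as np

# ndarray.sort is the Python-level counterpart of PyArray_Sort(self, axis, kind):
# it sorts the array in place along `axis` and returns None.
a = np.array([[3, 1, 2], [9, 7, 8]])
result = a.sort(axis=1, kind="stable")
print(result)  # None: `a` itself was modified
print(a)       # [[1 2 3]
               #  [7 8 9]]

# np.sort is the copying variant: the input is left untouched.
b = np.array([5, 4, 6])
c = np.sort(b, kind="quicksort")
print(b, c)    # [5 4 6] [4 5 6]
```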
One potential workaround I've been considering is to use llvmlite.binding.load_library_permanently to load the NumPy extension module "multiarray_umath" permanently, so that the C function pointer to PyArray_Sort remains stable.
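Whether a name is visible in the extension module's dynamic symbol table can also be checked without llvmlite, by opening the already-loaded module with ctypes. This is a minimal sketch; `is_exported` is a hypothetical helper. As far as I understand, NumPy defines internal implementations such as PyArray_Sort with the NPY_NO_EXPORT macro (hidden visibility on GCC/Clang builds), so I would expect only the module's `PyInit_` entry point to be reported as exported:

```python
import ctypes

try:  # NumPy >= 2.0
    from numpy._core import _multiarray_umath as multiarray_umath
except ImportError:  # NumPy 1.x
    from numpy.core import _multiarray_umath as multiarray_umath

# The extension is already loaded by `import numpy`; opening it again with
# ctypes just returns a handle to the same shared library.
lib = ctypes.CDLL(multiarray_umath.__file__)


def is_exported(name: str) -> bool:
    """Hypothetical helper: True if `name` resolves in the library's symbol table."""
    return hasattr(lib, name)  # CDLL raises AttributeError for unknown symbols


for name in ("PyInit__multiarray_umath", "PyArray_Sort"):
    print(f"{name}: {'exported' if is_exported(name) else 'not exported'}")
```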
There are two potential problems, which may or may not be solvable:
- PyArray_Sort would have to be exposed as a public symbol, which does not seem to be the case.
- PyArray_Sort uses types (PyObject, PyArrayObject) that might not have usable counterparts when defining the C function signature.
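One possible route around both problems, sketched under several assumptions: NumPy's C API does not depend on exported symbols at all. C extensions call import_array(), which reads the `_ARRAY_API` capsule from the `_multiarray_umath` module and indexes the `void **PyArray_API` function table inside it; the installed header `numpy/__multiarray_api.h` maps each API name to a slot in that table. The plain-ctypes sketch below (no Numba yet) parses the header for PyArray_Sort's slot, reads the pointer from the capsule and calls it. Assumptions: the header layout (`#define PyArray_Sort ... PyArray_API[N]`), that the function returns an `int` status and sorts in place (the guard stops if the header declares a different return type), and that `0` is the `NPY_QUICKSORT` value of `NPY_SORTKIND`. Since a PyArrayObject starts with the standard object header, the array is passed as an ordinary `PyObject *`; `pyarray_api_slot` is a hypothetical helper.

```python
import ctypes
import os
import re

import numpy as np

try:  # NumPy >= 2.0
    from numpy._core import _multiarray_umath as multiarray_umath
except ImportError:  # NumPy 1.x
    from numpy.core import _multiarray_umath as multiarray_umath


def pyarray_api_slot(name: str) -> tuple[int, str]:
    """Hypothetical helper: find `name` in NumPy's installed __multiarray_api.h.

    Returns (index into the PyArray_API table, the macro's C cast text).
    """
    header = os.path.join(np.get_include(), "numpy", "__multiarray_api.h")
    with open(header) as f:
        text = f.read()
    pattern = r"#define\s+" + re.escape(name) + r"\b(.*?)PyArray_API\[(\d+)\]"
    m = re.search(pattern, text, re.S)
    if m is None:
        raise LookupError(f"{name} not found in {header}")
    cast = " ".join(m.group(1).replace("\\", " ").split())
    return int(m.group(2)), cast


# _ARRAY_API is a PyCapsule (created with a NULL name) wrapping the table.
get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
get_pointer.restype = ctypes.c_void_p
get_pointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
api_table = ctypes.cast(get_pointer(multiarray_umath._ARRAY_API, None),
                        ctypes.POINTER(ctypes.c_void_p))

index, cast = pyarray_api_slot("PyArray_Sort")
print(index, cast)
# Only trust the prototype below if the header says the function returns int.
if not re.search(r"\(\s*int\s*\(\s*\*\s*\)", cast):
    raise RuntimeError(f"unexpected PyArray_Sort prototype: {cast}")

# Assumed prototype: int PyArray_Sort(PyArrayObject *, int, NPY_SORTKIND),
# with the enum passed as a C int. PYFUNCTYPE keeps the GIL held during the
# call, which is required because the function works on Python objects.
PyArraySortProto = ctypes.PYFUNCTYPE(ctypes.c_int, ctypes.py_object,
                                     ctypes.c_int, ctypes.c_int)
pyarray_sort = PyArraySortProto(api_table[index])

NPY_QUICKSORT = 0  # assumed value of the first NPY_SORTKIND enumerator
a = np.array([3.0, 1.0, 2.0])
status = pyarray_sort(a, -1, NPY_QUICKSORT)  # sorts `a` in place
print(status, a)
```

The table lives in process memory, so its address can differ between sessions; cached Numba code would presumably have to obtain it at runtime instead of storing it as a constant. How to call the pointer from nopython mode, where arrays are unboxed, is not covered by this sketch.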
Is it possible to access NumPy’s native sorting function from external jit-compiled code? I have limited experience with C++ internals and the NumPy build system. Any insights, workarounds, or recommendations would be greatly appreciated.
Thank you for your time!
Here is an incomplete attempt which is not working:
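import numpy as np
try:  # NumPy >= 2.0
    from numpy._core import _multiarray_umath as multiarray_umath
except ImportError:  # NumPy 1.x
    from numpy.core import _multiarray_umath as multiarray_umath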
from numba import njit
from numba.core import types, typing
from llvmlite.binding import load_library_permanently, address_of_symbol
# Stable load numpy extension
np_library_path = multiarray_umath.__file__
load_library_permanently(np_library_path)
# Check if symbol is public and address can be found
np_fn_name = 'PyArray_Sort'
# The 1st issue: 'PyArray_Sort' is not a public symbol
func_addr = address_of_symbol(np_fn_name)
if func_addr is None:
    raise RuntimeError(f"Could not find symbol {np_fn_name}")
print(f"Address of {np_fn_name}:", hex(func_addr))
# >>> This will raise a RuntimeError because the symbol is not publicly available
# The 2nd issue: Are there matching types to define the signature?
# >>> What is the return type for PyObject ???
return_type = types.pyobject
# >>> What is the argument type for PyArrayObject ???
arg_types = (types.pyobject, types.int64, types.int64)
np_fn_signature = typing.signature(return_type, *arg_types)
pyarray_sort = types.ExternalFunction(np_fn_name, np_fn_signature)
@njit(cache=True)
def wrapped_numpy_sort(arr, axis, kind):
    return pyarray_sort(arr, axis, kind)