
python - How to overload numpy.atleast_2d in Numba to include non-ndarray data types? - Stack Overflow


Numba's implementations of numpy.atleast_1d() and numpy.atleast_2d() only accept arrays, not other data types. I've attempted to overload atleast_2d myself, but I'm having trouble handling 2D reflected lists and 2D typed lists correctly. I need help with the proper type checks for these two cases to prevent the errors, which I find confusing.

Does anyone know the correct way to check for and handle 2D lists in a Numba overload?
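For comparison, this is what plain NumPy (outside Numba) returns for the same kinds of input as in the attempt below. An overload should reproduce these shapes; in particular, a 1D input becomes a single row of shape (1, n), not a column.

```python
import numpy as np

# Reference behaviour of np.atleast_2d in plain NumPy (no Numba involved).
cases = {
    "scalar": 1.0,
    "array 0D": np.array(1.0),
    "array 1D": np.ones((3,)),
    "array 2D": np.ones((1, 1)),
    "tuple 1D": (1, 2, 3),
    "tuple 2D": ((1, 2, 3), (4, 5, 6)),
    "pylist 1D": [1, 2, 3],
    "pylist 2D": [[1, 2, 3], [4, 5, 6]],
}

for name, value in cases.items():
    out = np.atleast_2d(value)
    print(f"{name:10s} -> shape {out.shape}, ndim {out.ndim}")
```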

Here is my attempt:

import numpy as np
from numba import njit, types, typed
from numba.extending import overload
from numba.core.errors import TypingError

@overload(np.atleast_2d)
def ovl_atleast_2d(a):
    '''
    Implement np.atleast_2d.

    Example:
        @njit
        def use_atleast_2d(a):
            return np.atleast_2d(a).ndim

        print(f"Result for scalar:   {use_atleast_2d(1.)} ndim")
        print(f"Result for array 0D: {use_atleast_2d(np.array(1.))} ndim")
        print(f"Result for array 1D: {use_atleast_2d(np.ones((3,)))} ndim")
        print(f"Result for array 2D: {use_atleast_2d(np.ones((1,1)))} ndim")
        print(f"Result for tuple 1D: {use_atleast_2d((1,2,3))} ndim")
        print(f"Result for tuple 2D: {use_atleast_2d(((1,2,3), (4,5,6)))} ndim")
        print(f"Result for pylist 1D:  {use_atleast_2d([1,2,3])} ndim")

        # print(f"Result for pylist 2D:  {use_atleast_2d([[1,2,3], [4,5,6]])} ndim")
        # => TypeError: cannot reflect element of reflected container: reflected list(reflected list(int64)<iv=None>)<iv=None>

        print(f"Result for nblist 1D: {use_atleast_2d(typed.List([1,2,3]))} ndim")

        nblist2d = typed.List([typed.List([1,2,3]), typed.List([4,5,6])])
        # print(f"Result for nblist 2D: {use_atleast_2d(nblist2d)} ndim")
        # => KeyError: 'Can only index numba types with slices with no start or stop, got 0.'

        # Expected output:
        # Result for scalar:   2 ndim
        # Result for array 0D: 2 ndim
        # Result for array 1D: 2 ndim
        # Result for array 2D: 2 ndim
        # Result for tuple 1D: 2 ndim
        # Result for tuple 2D: 2 ndim
        # Result for pylist 1D:  2 ndim
        # Result for pylist 2D:  2 ndim  (Error)
        # Result for nblist 1D:  2 ndim
        # Result for nblist 2D:  2 ndim  (Error)
    '''
    if isinstance(a, types.Array):
        if a.ndim == 0:
            return lambda a: np.array([[a]])
        elif a.ndim == 1:
            return lambda a: a[np.newaxis, :]
        else:
            return lambda a: a

    elif isinstance(a, (types.Number, types.Boolean)):
        return lambda a: np.array([[a]])

    elif isinstance(a, (types.Sequence, types.Tuple)):
        # For a Python sequence or tuple, first convert to an array then ensure it is at least 2-D.
        return lambda a: np.atleast_2d(np.array(a))

    elif isinstance(a, types.containers.ListType):

        # 1D-list
        if isinstance(a[0].dtype, (types.Number, types.Boolean)):

            target_dtype = a[0].dtype

            def impl(a):
                ret = np.empty((len(a), 1), dtype=target_dtype)
                for i, v in enumerate(a):
                    ret[i, 0] = v
                return ret
            return impl

        # 2D-list
        elif isinstance(a[0].dtype, types.containers.ListType):
            if isinstance(a[0][0].dtype, (types.Number, types.Boolean)):

                target_dtype = a[0][0].dtype

                def impl(a):
                    nrows = len(a)
                    ncols = len(a[0])
                    ret = np.empty((nrows, ncols), dtype=target_dtype)
                    for i in range(nrows):
                        for k in range(ncols):
                            ret[i, k] = a[i][k]
                    return ret
                return impl
    else:
        raise TypingError("Argument can't be converted into ndarray.")

Here is the source code for the atleast implementations: .py#L5780

Edit: I have found a remark within the related function asarray:

# Nested lists cannot be unpacked, therefore only single lists are
# permitted and these conform to Sequence and can be unpacked along on
# the same path as Tuple.

.py#L4278

I assume that if nested lists cannot be unpacked in asarray, the same applies to atleast_2d.
