
python - Numba promotes types differently to numpy - Stack Overflow


When adding (or multiplying, dividing, subtracting, etc.) a Python float to a NumPy array, NumPy preserves the array's dtype, whereas Numba promotes the result to float64. How can I modify the overload of ndarray.__add__ (and the other operators) so that the Python float is cast to the array's dtype and the result keeps that dtype?

Ideally, I don't want to modify my functions; I'd rather implement a new overload of the underlying addition (etc.) functions, as there are many instances of this in my code.

Code demonstrating the issue. I would like a function decorated with njit to behave consistently with NumPy:

import numpy as np
import numba as nb

def func(array):
    return array + 1.0

numba_func = nb.njit(func)

a_f64 = np.ones(1, dtype=np.float64)
a_f32 = np.ones(1, dtype=np.float32)

for i in (a_f64, a_f32):
    print(i.dtype)
    print(func(i).dtype)
    print(numba_func(i).dtype, end="\n\n")

Output (with numpy 2.1.3 and numba 0.61.0):

float64
float64
float64

float32
float32
float64

Asked yesterday by Nin17.
  • I thought this might have been one of the things that changed in NumPy 2.0, but apparently not: the dtype behavior is the same on 1.21.6. – user2357112
  • I believe it comes from this – roganjosh

1 Answer

It comes from this:

Numpy will most often return a float64 as a result of a computation with mixed integer and floating-point operands (a typical example is the power operator **). Numba by contrast will select the highest precision amongst the floating-point operands, so for example float32 ** int32 will return a float32, regardless of the input values. This makes performance characteristics easier to predict, but you should explicitly cast the input to float64 if you need the extra precision.
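As a rough illustration, the rule quoted above can be modelled in pure NumPy. `quoted_numba_rule` is a hypothetical toy helper written for this example, not part of Numba and not its actual typing code, and it only covers operand mixes that include at least one floating-point type. `np.result_type` shows what NumPy's own type-based promotion picks for the same dtypes.

```python
import numpy as np


def quoted_numba_rule(*dtypes):
    """Hypothetical toy model of the quoted rule: pick the highest-precision
    floating-point operand. Not Numba's real implementation."""
    floats = [np.dtype(d) for d in dtypes if np.issubdtype(d, np.floating)]
    if not floats:
        raise ValueError("this toy rule only covers mixes with a float operand")
    # Highest precision = largest itemsize among the floating-point operands.
    return max(floats, key=lambda d: d.itemsize)


# float32 mixed with int32: NumPy widens to float64, the quoted rule keeps float32.
print(np.result_type(np.float32, np.int32))     # float64
print(quoted_numba_rule(np.float32, np.int32))  # float32

# float32 array plus a float64 operand (like the literal 1.0 in the question):
# under the quoted rule the float64 operand wins.
print(quoted_numba_rule(np.float32, np.float64))  # float64
```

The last line is the case in the question: once Numba sees the literal as a float64 operand, the quoted rule selects float64 for the result.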

You can fix your example simply by using:

import numpy as np
import numba as nb

def func(array):
    # A float32 scalar no longer forces a float64 result for float32 arrays
    return array + np.float32(1.0)

numba_func = nb.njit(func)

a_f64 = np.ones(1, dtype=np.float64)
a_f32 = np.ones(1, dtype=np.float32)

for i in (a_f64, a_f32):
    print(i.dtype)
    print(func(i).dtype)
    print(numba_func(i).dtype, end="\n\n")

The problem is that Numba interprets your 1.0 as a 64-bit float, and that upcasts the whole result.

I don't think this makes a whole lot of sense, but it does stop the upcast to 64-bit.

This is my output:

float64
float64
float64

float32
float32
float32

This is with NumPy 1.26.4 and Numba 0.60.0 on Python 3.12.0. I don't think I've solved the whole problem here.
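The promotion described in this answer can be sketched in pure NumPy, assuming (as the answer says) that Numba types the literal 1.0 as float64. It only applies `np.result_type` to dtypes, so it illustrates the promotion table rather than running anything through Numba.

```python
import numpy as np

a_f32 = np.ones(1, dtype=np.float32)

# In NumPy a Python float is a "weak" scalar, so the array's dtype wins.
print((a_f32 + 1.0).dtype)  # float32

# Assumption from the answer: Numba treats the literal 1.0 as float64,
# so the operation behaves like a float32/float64 mix, which gives float64.
print(np.result_type(np.float32, np.float64))  # float64

# With an explicit float32 scalar (the fix above), float32 inputs stay
# float32 and float64 inputs stay float64.
print(np.result_type(np.float32, np.float32))  # float32
print(np.result_type(np.float64, np.float32))  # float64
```

This matches the output above: `np.float32(1.0)` never narrows a float64 array, so the fix is safe for both input dtypes.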
