When adding (or subtracting, multiplying, dividing, etc.) a Python float to a NumPy array, NumPy preserves the dtype of the array, whereas Numba promotes the result to float64. How can I modify the overload of ndarray.__add__ etc. so that the Python float is cast to the array's dtype and the result has the same dtype as the array?
Ideally, I don't want to modify my functions; I would rather implement a new overload of the underlying addition (etc.) functions, as there are many instances of this in my code.
Code to demonstrate the issue (I would like consistency with NumPy in a function decorated with njit):
import numpy as np
import numba as nb

def func(array):
    return array + 1.0

numba_func = nb.njit(func)

a_f64 = np.ones(1, dtype=np.float64)
a_f32 = np.ones(1, dtype=np.float32)

for i in (a_f64, a_f32):
    print(i.dtype)
    print(func(i).dtype)
    print(numba_func(i).dtype, end="\n\n")
Output (with numpy 2.1.3 and numba 0.61.0):
float64
float64
float64
float32
float32
float64
Asked yesterday by Nin17
2 Comments
- I thought this might have been one of the things that changed in NumPy 2.0, but apparently not - the dtype behavior is the same on 1.21.6. – user2357112 Commented yesterday
- I believe it comes from this – roganjosh Commented yesterday
1 Answer
It comes from this passage in the Numba documentation:
Numpy will most often return a float64 as a result of a computation with mixed integer and floating-point operands (a typical example is the power operator **). Numba by contrast will select the highest precision amongst the floating-point operands, so for example float32 ** int32 will return a float32, regardless of the input values. This makes performance characteristics easier to predict, but you should explicitly cast the input to float64 if you need the extra precision.
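The NumPy half of that comparison can be checked directly with np.result_type. This is only a small sketch of NumPy's own promotion rules. It says nothing about what Numba does, and the variable names are just for illustration:

```python
import numpy as np

# A Python float combined with float32 stays float32 in NumPy,
# which is why `array + 1.0` keeps the array's dtype outside of njit.
with_python_float = np.result_type(np.float32, 1.0)

# An explicit float64 operand promotes the result to float64.
with_float64 = np.result_type(np.float32, np.float64)

# Mixed integer and floating-point operands: NumPy picks float64 here.
with_int32 = np.result_type(np.float32, np.int32)

print(with_python_float, with_float64, with_int32)
```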
You can fix your example simply by using:
import numpy as np
import numba as nb

def func(array):
    return array + np.float32(1.0)

numba_func = nb.njit(func)

a_f64 = np.ones(1, dtype=np.float64)
a_f32 = np.ones(1, dtype=np.float32)
# print(a_f32)

for i in (a_f64, a_f32):
    print(i.dtype)
    print(func(i).dtype)
    print(numba_func(i).dtype, end="\n\n")
The problem is that your 1.0 is being interpreted as a 64-bit float, and it upcasts the lot.
I don't think this makes a whole lot of sense, but it fixes the upcasting to 64-bit.
This is my output:
float64
float64
float64
float32
float32
float32
This is with numpy 1.26.4 and numba 0.60.0 in Python 3.12.0. I don't think I've solved the whole problem here.