I have a 2D JAX array with some NaN values:
array_2d = jnp.array([
[jnp.nan, 1, 2, jnp.nan, 3],
[10 ,jnp.nan, jnp.nan, 20,jnp.nan]
])
and want an array that contains, for each row, only the non-NaN values. The result therefore has the same number of rows, and either fewer columns or the same number of columns with NaN values padded at the end. In this case, the result should be
array_2d = jnp.array([
[1, 2, 3],
[10, 20, jnp.nan]
])
The order (among non-nan values) should stay the same.
To make things easier, I know that each row has at most k (in this case 3) non-NaN values. Getting the indices of the non-NaN values is very easy, but "moving them to the front" is harder.
I tried to work on a row-by-row basis; the following function does work:
# we want to vmap this over each row
def get_non_nan_values(row_vals):
    ret_arr = jnp.zeros(3)  # there are at most 3 non-nan values per row
    row_mask = ~jnp.isnan(row_vals)
    ret_vals = row_vals[row_mask]  # gets all (at most 3) non-nan values, but the size is dynamic; under vmap this raises NonConcreteBooleanIndexError
    ret_arr = ret_arr.at[:ret_vals.shape[0]].set(ret_vals)  # this returns a FIXED SIZE array
    return ret_arr
# the following works:
get_non_nan_values(array_2d[0,:]) # should return [1,2,3]
However, I can't vmap this. Even though I made sure that the returned array always has the same size, the line ret_vals = row_vals[row_mask] causes problems, since its result has a dynamic size. Does anyone know how to circumvent this? I believe that functions like jnp.where don't help either.
Here is the full MWE:
import jax
import jax.numpy as jnp
array_2d = jnp.array([
[jnp.nan, 1, 2, jnp.nan, 3],
[10 ,jnp.nan, jnp.nan, 20,jnp.nan]
])
# we want to get -- efficiently -- all non-nan values per row.
# we know that each row has at most 3 non-nan values
# we will vmap this over each row
def get_non_nan_values(row_vals):
    ret_arr = jnp.zeros(3)  # there are at most 3 non-nan values per row
    row_mask = ~jnp.isnan(row_vals)
    ret_vals = row_vals[row_mask]  # gets all (at most 3) non-nan values, but the size is dynamic; under vmap this raises NonConcreteBooleanIndexError
    ret_arr = ret_arr.at[:ret_vals.shape[0]].set(ret_vals)  # this returns a FIXED SIZE array
    return ret_arr
# the following works:
get_non_nan_values(array_2d[0,:]) # should return [1,2,3]
# we now vmap
non_nan_vals = jax.vmap(get_non_nan_values)(array_2d) # this gives error: NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
NB: The array will be very large in practice and have many nan values, while k (the number of non-nan values) is on the order of 10 or 100.
Thank you very much!
asked Feb 16 at 21:13 by black, edited Feb 16 at 22:11
4 Answers
By first padding each row with a fill value at the end, you can rely on jnp.nonzero and its size and fill_value arguments, which define a fixed output size and the index used as filler when fewer than size elements are found. Here is a minimal example:
import jax.numpy as jnp
import jax
array_2d = jnp.array([
[jnp.nan, 1, 2, jnp.nan, 3],
[10 ,jnp.nan, jnp.nan, 20,jnp.nan]
])
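@jax.vmap
def get_non_nan_values(row_vals, size=3):
    # pad with one NaN so that index -1 points at a NaN filler
    padded = jnp.pad(row_vals, (0, 1), constant_values=jnp.nan)
    # fixed-size index tuple; missing entries are filled with -1
    non_nan = jnp.nonzero(~jnp.isnan(padded), size=size, fill_value=-1)
    return padded[non_nan]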
get_non_nan_values(array_2d)
Which returns:
Array([[ 1., 2., 3.],
[10., 20., nan]], dtype=float32)
I think this solution is a bit more compact and clearer in intent; however, I have not checked its performance.
I hope this helps!
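To make the role of size and fill_value more concrete, here is a minimal single-row sketch (no vmap) that prints the intermediate index array. The variable names are only illustrative, and the row is taken from the question's example:

```python
import jax.numpy as jnp

row = jnp.array([10.0, jnp.nan, jnp.nan, 20.0, jnp.nan])

# Append one NaN so that index -1 always points at a NaN "filler" slot.
padded = jnp.pad(row, (0, 1), constant_values=jnp.nan)

# With size=3 the index array always has exactly 3 entries; when fewer than
# 3 matches exist, the remaining entries are set to fill_value (-1 here).
(idx,) = jnp.nonzero(~jnp.isnan(padded), size=3, fill_value=-1)
values = padded[idx]

print(idx)     # positions 0 and 3, then the filler index -1
print(values)  # 10, 20, then NaN from the padded slot
```

Because the output length is fixed by size rather than by the data, the same code can be traced by vmap or jit.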
I think you can do what you want with this function, which rather than sorting the array (as I commented), sorts and masks the indices of the non-nan values:
from functools import partial
import jax
import jax.numpy as jnp
@partial(jax.jit, static_argnums=(1,))
def func(array, k=3):
    m, n = array.shape[-2:]
    indices = jnp.broadcast_to(jnp.arange(n)[None, :], (m, n))
    sorted_masked_indices = jnp.sort(jnp.where(jnp.isnan(array), jnp.nan, indices))
    array_rearranged = array[jnp.arange(m)[:, None], sorted_masked_indices.astype(int)]
    return jnp.where(jnp.isnan(sorted_masked_indices), jnp.nan, array_rearranged)[:, :k]
Test:
import numpy as np
rng = np.random.default_rng(0)
k = 3
a = rng.random((12, 6))
a[np.arange(12)[:, None], rng.integers(0, 6, (12, 6))] = np.nan
print(a)
print(func(a, k=k))
Gives:
[[0.63696169 nan nan 0.01652764 0.81327024 nan]
[ nan 0.72949656 nan nan 0.81585355 nan]
[ nan 0.03358558 nan nan nan nan]
[0.29971189 nan nan nan nan 0.64718951]
[ nan nan nan 0.98083534 nan 0.65045928]
[ nan nan 0.13509651 0.72148834 nan nan]
[ nan 0.88948783 0.93404352 0.3577952 nan nan]
[ nan 0.33791123 0.391619 0.89027435 nan nan]
[ nan 0.83264415 nan nan 0.87648423 nan]
[0.33611706 nan nan 0.79632427 nan 0.0520213 ]
[ nan nan 0.09075305 0.58033239 nan nan]
[ nan 0.94211311 nan nan 0.62910815 nan]]
[[0.6369617 0.01652764 0.8132702 ]
[0.72949654 0.81585354 nan]
[0.03358557 nan nan]
[0.29971188 0.6471895 nan]
[0.9808353 0.6504593 nan]
[0.1350965 0.72148836 nan]
[0.88948786 0.9340435 0.3577952 ]
[0.33791122 0.391619 0.89027435]
[0.83264416 0.8764842 nan]
[0.33611706 0.79632425 0.0520213 ]
[0.09075305 0.5803324 nan]
[0.9421131 0.62910813 nan]]
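For readers who want to follow the index-sorting idea step by step, here is a single-row walk-through. It is a sketch rather than the function above: it replaces NaN indices by a dummy index 0 before the integer cast (my own variation, so the gather never depends on how NaN converts to an integer), and all variable names are illustrative.

```python
import jax.numpy as jnp

row = jnp.array([jnp.nan, 1.0, 2.0, jnp.nan, 3.0])
k = 3  # known upper bound on non-NaN entries (from the question)

positions = jnp.arange(row.shape[0])
# Keep the position where the value is valid, NaN elsewhere: [nan, 1, 2, nan, 4]
masked = jnp.where(jnp.isnan(row), jnp.nan, positions)
# jnp.sort places NaN last, so valid positions move to the front: [1, 2, 4, nan, nan]
sorted_masked = jnp.sort(masked)
# Use a dummy index 0 for the NaN slots so the gather is always in range.
gather_idx = jnp.where(jnp.isnan(sorted_masked), 0, sorted_masked).astype(jnp.int32)
# Re-apply the NaN mask after gathering, then keep the first k columns.
result = jnp.where(jnp.isnan(sorted_masked), jnp.nan, row[gather_idx])[:k]

print(result)  # the three valid values in their original order
```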
With the stable=True option, argsort on a boolean array is guaranteed to preserve the relative order among the False elements and among the True elements. So this should do the trick:
def get_non_nan_values(row_vals):
    return row_vals[jnp.argsort(jnp.isnan(row_vals), stable=True)[:3]]
However, for wide rows, sorting the entire row seems unnecessary when we already know there are at most 3 non-NaN values. So here is another simple approach using jax.lax.top_k:
def get_top_3_non_nan(row_vals):
    return row_vals[jax.lax.top_k(~jnp.isnan(row_vals), 3)[1]]
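As a hedged variant of the top_k idea, the sketch below casts the mask to int32 (so it does not rely on top_k accepting boolean input) and takes k as a parameter. The name top_k_non_nan and the k argument are hypothetical. Note that whether the original order is preserved depends on how top_k orders tied scores, which is worth checking against the jax.lax.top_k documentation for your JAX version; the code itself only assumes that the k selected entries are the valid ones.

```python
from functools import partial

import jax
import jax.numpy as jnp


def top_k_non_nan(row_vals, k):
    # Score valid entries as 1 and NaN entries as 0 (hypothetical helper).
    score = (~jnp.isnan(row_vals)).astype(jnp.int32)
    # Indices of the k highest scores: all valid entries first, then NaNs.
    _, idx = jax.lax.top_k(score, k)
    return row_vals[idx]


array_2d = jnp.array([
    [jnp.nan, 1, 2, jnp.nan, 3],
    [10, jnp.nan, jnp.nan, 20, jnp.nan],
])

# k is bound with partial so that vmap only maps over the rows.
result = jax.vmap(partial(top_k_non_nan, k=3))(array_2d)
print(result)
```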
I would do this using vmap of argsort of isnan:
import jax
import jax.numpy as jnp
array_2d = jnp.array([
[jnp.nan, 1, 2, jnp.nan, 3],
[10 ,jnp.nan, jnp.nan, 20,jnp.nan]
])
result = jax.vmap(lambda x: x[jnp.argsort(jnp.isnan(x))])(array_2d)
print(result)
# [[ 1. 2. 3. nan nan]
# [10. 20. nan nan nan]]
This approach uses static shapes, and thus will be compatible with jit.
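If only the first k columns are needed, the same idea can be wrapped in jit with k as a static argument. This is a sketch under my own assumptions: the name compact_non_nan is hypothetical, it sorts along the last axis instead of using vmap, and it passes stable=True explicitly so that ties keep their original order.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="k")
def compact_non_nan(array, k):
    # isnan is False for values to keep and True for NaN, so a stable
    # argsort lists the kept positions first, in their original order.
    order = jnp.argsort(jnp.isnan(array), axis=-1, stable=True)
    # Gather per row and keep only the first k columns (k must be static).
    return jnp.take_along_axis(array, order[..., :k], axis=-1)


array_2d = jnp.array([
    [jnp.nan, 1, 2, jnp.nan, 3],
    [10, jnp.nan, jnp.nan, 20, jnp.nan],
])

out = compact_non_nan(array_2d, k=3)
print(out)  # rows: 1, 2, 3 and 10, 20, NaN
```

Keeping k static means a new value of k triggers a recompilation, which is usually acceptable since k is fixed per problem.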
jnp.sort(array_2d)[:, :k]? That works in this example, but only because the values are already sorted. – Nin17, commented Feb 16 at 21:43