indexing - Jax numpy extracting non-nan values gives NonConcreteBooleanIndexError - Stack Overflow

I have a jax 2d array with some nan-values

array_2d = jnp.array([
    [jnp.nan,        1,       2,   jnp.nan,    3],
    [10     ,jnp.nan,   jnp.nan,        20,jnp.nan]
    ])

and want to get an array which contains, for each row, only the non-nan values. The resulting array thus has the same number of rows, and either fewer columns or the same number of columns with nan values padded at the end. So in this case, the result should be

array_2d = jnp.array([
    [1,   2,      3],
    [10,  20,jnp.nan]
    ])

The order (among non-nan values) should stay the same.

To make things easier, I know that each row has at most k (in this case 3) non-nan values. Getting the indices of the non-nan values is very easy, but "moving them to the front" is harder.

I tried to work on a row-by-row basis; the following function does work:

# we want to vmap this over each row
def get_non_nan_values(row_vals):
    ret_arr = jnp.zeros(3) # there are at most 3 non-nan values per row
    row_mask = ~jnp.isnan(row_vals)
    ret_vals = row_vals[row_mask] # this gets all (at most 3) non-nan values. However, the size of the result is dynamic (data-dependent), so under vmap this raises NonConcreteBooleanIndexError.
    ret_arr = ret_arr.at[:ret_vals.shape[0]].set(ret_vals) # this returns a FIXED SIZE array
    return ret_arr

# the following works:
get_non_nan_values(array_2d[0,:]) # should return [1,2,3]

However, I can't vmap this. Even though I made sure that the returned array always has the same size, the line ret_vals = row_vals[row_mask] causes problems, since its result has a dynamic size. Does anyone know how to circumvent this? I believe that functions like jnp.where etc. don't help either.

Here is the full MWE:

import jax
import jax.numpy as jnp

array_2d = jnp.array([
    [jnp.nan,        1,       2,   jnp.nan,    3],
    [10     ,jnp.nan,   jnp.nan,        20,jnp.nan]
    ])

# we want to get -- efficiently -- all non-nan values per row.
# we know that each row has at most 3 non-nan values

# we will vmap this over each row
def get_non_nan_values(row_vals):
    ret_arr = jnp.zeros(3) # there are at most 3 non-nan values per row
    row_mask = ~jnp.isnan(row_vals)
    ret_vals = row_vals[row_mask] # this gets all (at most 3) non-nan values. However, the size of the result is dynamic (data-dependent), so under vmap this raises NonConcreteBooleanIndexError.
    ret_arr = ret_arr.at[:ret_vals.shape[0]].set(ret_vals) # this returns a FIXED SIZE array
    return ret_arr

# the following works:
get_non_nan_values(array_2d[0,:]) # should return [1,2,3]

# we now vmap
non_nan_vals = jax.vmap(get_non_nan_values)(array_2d) # this gives error: NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

NB: The array will be very large in practice and have many nan values, while k (the number of non-nan values) is on the order of 10 or 100.

Thank you very much!

asked Feb 16 at 21:13 by black, edited Feb 16 at 22:11

Comments:
  • Do the values have to remain in the same order? I.e., could you just do jnp.sort(array_2d)[:, :k]? That works in this example, but only because the values are already sorted. – Nin17 Commented Feb 16 at 21:43
  • Yes, they should remain in the same order, thanks for the question, I'll clarify that. – black Commented Feb 16 at 22:10

4 Answers


If you first pad each row with a fill value at the end, you can rely on jnp.nonzero and its size and fill_value arguments: size fixes the length of the output, and fill_value is the index used for the remaining slots when a row has fewer than size non-nan entries. Pointing fill_value at the padded element (index -1) makes those slots read back as nan. Here is a minimal example:

import jax.numpy as jnp
import jax

array_2d = jnp.array([
    [jnp.nan,        1,       2,   jnp.nan,    3],
    [10     ,jnp.nan,   jnp.nan,        20,jnp.nan]
    ])


@jax.vmap
def get_non_nan_values(row_vals, size=3):
    padded = jnp.pad(row_vals, (0, 1), constant_values=jnp.nan)
    non_nan = jnp.nonzero(~jnp.isnan(padded), size=size, fill_value=-1)
    return padded[non_nan]

get_non_nan_values(array_2d)

Which returns:

Array([[ 1.,  2.,  3.],
       [10., 20., nan]], dtype=float32)

I think this solution is a bit more compact and clearer in intent; however, I have not checked the performance.

I hope this helps!
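If k should be a parameter rather than a hard-coded default, one possible variation of the padding idea is sketched below. The names compact_rows and one_row are hypothetical, and using the padded slot's positive index n as fill_value instead of -1 is only a stylistic assumption; the technique is the same jnp.nonzero with size and fill_value.

```python
import jax
import jax.numpy as jnp


def compact_rows(array, k):
    """Move the non-nan values of each row to the front, keeping k columns."""
    n = array.shape[-1]  # index n is the extra nan slot appended by jnp.pad

    def one_row(row):
        padded = jnp.pad(row, (0, 1), constant_values=jnp.nan)
        # size=k fixes the output length; unused slots point at the padded nan.
        (idx,) = jnp.nonzero(~jnp.isnan(padded), size=k, fill_value=n)
        return padded[idx]

    return jax.vmap(one_row)(array)


array_2d = jnp.array([
    [jnp.nan, 1, 2, jnp.nan, 3],
    [10, jnp.nan, jnp.nan, 20, jnp.nan],
])

# k determines an output shape, so it has to be static under jit.
result = jax.jit(compact_rows, static_argnums=1)(array_2d, 3)
print(result)
```

Note that if a row held more than k non-nan values, the extra ones would be dropped silently, because jnp.nonzero truncates to size.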

I think you can do what you want with this function, which rather than sorting the array (as I commented), sorts and masks the indices of the non-nan values:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def func(array, k=3):
    m, n = array.shape[-2:]
    indices = jnp.broadcast_to(jnp.arange(n)[None, :], (m, n))
    sorted_masked_indices = jnp.sort(jnp.where(jnp.isnan(array), jnp.nan, indices))
    array_rearranged = array[jnp.arange(m)[:, None], sorted_masked_indices.astype(int)]
    return jnp.where(jnp.isnan(sorted_masked_indices), jnp.nan, array_rearranged)[:, :k]

Test:

import numpy as np
rng = np.random.default_rng(0)
k = 3

a = rng.random((12, 6))
a[np.arange(12)[:, None], rng.integers(0, 6, (12, 6))] = np.nan

print(a)
print(func(a, k=k))

Gives:

[[0.63696169        nan        nan 0.01652764 0.81327024        nan]
 [       nan 0.72949656        nan        nan 0.81585355        nan]
 [       nan 0.03358558        nan        nan        nan        nan]
 [0.29971189        nan        nan        nan        nan 0.64718951]
 [       nan        nan        nan 0.98083534        nan 0.65045928]
 [       nan        nan 0.13509651 0.72148834        nan        nan]
 [       nan 0.88948783 0.93404352 0.3577952         nan        nan]
 [       nan 0.33791123 0.391619   0.89027435        nan        nan]
 [       nan 0.83264415        nan        nan 0.87648423        nan]
 [0.33611706        nan        nan 0.79632427        nan 0.0520213 ]
 [       nan        nan 0.09075305 0.58033239        nan        nan]
 [       nan 0.94211311        nan        nan 0.62910815        nan]]
[[0.6369617  0.01652764 0.8132702 ]
 [0.72949654 0.81585354        nan]
 [0.03358557        nan        nan]
 [0.29971188 0.6471895         nan]
 [0.9808353  0.6504593         nan]
 [0.1350965  0.72148836        nan]
 [0.88948786 0.9340435  0.3577952 ]
 [0.33791122 0.391619   0.89027435]
 [0.83264416 0.8764842         nan]
 [0.33611706 0.79632425 0.0520213 ]
 [0.09075305 0.5803324         nan]
 [0.9421131  0.62910813        nan]]
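To check outputs like the one above without reading them by eye, a slow but obvious NumPy reference can help. This is only a sketch; reference_compact is a hypothetical helper, not part of the answer, and it uses the very boolean-mask indexing that JAX cannot trace, which is fine in plain NumPy.

```python
import numpy as np


def reference_compact(array, k):
    """Row-by-row reference: non-nan values first, nan-padded to k columns."""
    array = np.asarray(array, dtype=float)
    out = np.full((array.shape[0], k), np.nan)
    for i, row in enumerate(array):
        vals = row[~np.isnan(row)][:k]  # dynamic size is fine in NumPy
        out[i, :vals.size] = vals
    return out


example = np.array([
    [np.nan, 1, 2, np.nan, 3],
    [10, np.nan, np.nan, 20, np.nan],
])
expected = reference_compact(example, 3)
print(expected)
```

The result of func(a, k=k) could then be compared against reference_compact(a, k) with np.testing.assert_allclose(..., equal_nan=True), using a tolerance loose enough for float32 if JAX runs in its default precision.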

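With the stable=True option, argsort is guaranteed to preserve the relative order of equal elements. Applied to the boolean array jnp.isnan(row_vals), it therefore puts the non-nan positions (False) first, in their original order, followed by the nan positions (True). So this should do the trick: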

def get_non_nan_values(row_vals):
    return row_vals[jnp.argsort(jnp.isnan(row_vals), stable=True)[:3]]

However, for wide rows, sorting the entire row seems unnecessary when we already know there are at most 3 non-nan values. Another simple approach uses jax.lax.top_k:

def get_top_3_non_nan(row_vals):
  return row_vals[jax.lax.top_k(~jnp.isnan(row_vals), 3)[1]]

I would do this using vmap of argsort of isnan:

import jax
import jax.numpy as jnp

array_2d = jnp.array([
    [jnp.nan,        1,       2,   jnp.nan,    3],
    [10     ,jnp.nan,   jnp.nan,        20,jnp.nan]
])

result = jax.vmap(lambda x: x[jnp.argsort(jnp.isnan(x))])(array_2d)
print(result)
# [[ 1.  2.  3. nan nan]
#  [10. 20. nan nan nan]]

This approach uses static shapes, and thus will be compatible with jit.
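To keep only the first k columns, as the question asks, the same argsort idea can be combined with a static k under jit. The following is a minimal sketch; leading_non_nan is a hypothetical name, and it assumes jnp.argsort performs a stable sort by default, as its documentation states for recent JAX versions.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="k")
def leading_non_nan(array, k):
    # isnan gives False for real values; a stable sort puts those positions
    # first, in their original order, followed by the nan positions.
    order = jnp.argsort(jnp.isnan(array), axis=-1)
    # Gather only the first k positions of each row.
    return jnp.take_along_axis(array, order[:, :k], axis=-1)


array_2d = jnp.array([
    [jnp.nan, 1, 2, jnp.nan, 3],
    [10, jnp.nan, jnp.nan, 20, jnp.nan],
])

result = leading_non_nan(array_2d, k=3)
print(result)
```

Rows with fewer than k real values are padded with nan automatically, because the remaining sorted positions point at nan entries of the original row.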
