python - Problems when boolean indexing in Jax, getting NonConcreteBooleanIndexError - Stack Overflow

I'm trying to create a CustomProblem that inherits from the BaseProblem class in TensorNEAT, a JAX-based library. In my implementation of this class's evaluate function I use a boolean mask, but I can't get it to work. My code raises jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[n,n]), which I think is because some of my arrays don't have a definite shape. How do I get around this?

Consider this example in NumPy:

import numpy as np

ran_int = np.random.randint(1, 5, size=(2, 2))
print(ran_int)

ran_bool = np.random.randint(0,2, size=(2,2), dtype=bool)
print(ran_bool)

a = (ran_int[ran_bool]>0).astype(int)
print(a)

It could give an output like this:

[[2 2]
 [3 4]]
[[ True False]
 [ True  True]]
[1 1 1] #Is 1D and has fewer elements than before the boolean mask was applied!

But in JAX, the same approach results in the NonConcreteBooleanIndexError I got.

#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
        # do batch forward for all inputs (using jax.vmap).
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, self.inputs
        )  # should be shape (n, 1)

        #calculating pairwise labels and predictions
        pairwise_labels = self.labels - self.labels.T # shape (n, n)
        pairwise_predictions = predict - predict.T  # shape (n, n)

        #finding which pairs to keep
        pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold 
        print(pairs_to_keep.shape) #this prints (n, n)

        pairwise_labels = pairwise_labels[pairs_to_keep] #ERROR HAPPENS HERE
        pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
        print(pairwise_labels.shape) #want this to print a 1D array that potentially has less elements than n*n depending on the boolean mask

        pairwise_predictions = pairwise_predictions[pairs_to_keep] #WOULD HAPPEN HERE TOO IF THIS PART WAS FIRST
        pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
        print(pairwise_predictions.shape) #want this to print a 1D array that potentially has less elements than n*n depending on the boolean mask

        # calculate loss
        loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (n)

        # reduce loss to a scalar
        loss = jnp.mean(loss)

        # return negative loss as fitness
        # TensorNEAT maximizes fitness, equivalent to minimizing loss
        return -loss
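The failure can be reproduced without TensorNEAT. The sketch below is a minimal illustration, assuming only that evaluate is traced by jax.jit (which jitable = True suggests). The names select_large and zero_small are made up for this example. Boolean-mask indexing works when run eagerly but fails under jit, while the three-argument jnp.where keeps the shape fixed and traces without error.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

def select_large(x):
    mask = x > 2.0
    return x[mask]  # output length depends on the values stored in x

# Eagerly, the mask is a concrete array, so this behaves like NumPy.
eager_result = select_large(x)
print(eager_result.shape)  # (3,)

# Under jit, x is an abstract tracer, so the output shape cannot be determined.
try:
    jax.jit(select_large)(x)
    jit_error = None
except jax.errors.NonConcreteBooleanIndexError as err:
    jit_error = err
print(type(jit_error).__name__)

@jax.jit
def zero_small(x):
    # Shape-preserving alternative: masked-out entries become 0 instead of vanishing.
    return jnp.where(x > 2.0, x, 0.0)

print(zero_small(x).shape)  # (2, 3)
```

The output of zero_small always has the shape of its input, so any later reduction has to account for the masked-out entries explicitly.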

I was considering using jnp.where to solve the issue, but the resulting pairwise_labels and pairwise_predictions have a different shape than what I expect (namely (n, n)) as seen in the code below:

#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
        # do batch forward for all inputs (using jax.vmap).
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, self.inputs
        )  # should be shape (n, 1)

        #calculating pairwise labels and predictions
        pairwise_labels = self.labels - self.labels.T # shape (n, n)
        pairwise_predictions = predict - predict.T  # shape (n, n)

        #finding which pairs to keep
        pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold 
        print(pairs_to_keep.shape) #this prints (n, n)


        pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf) #one problem is that now I have -inf instead of discarding the element entirely
        pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
        print(pairwise_labels.shape) # shape (n, n)

        pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf) #one problem is that now I have -inf instead of discarding the element entirely
        pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
        print(pairwise_predictions.shape) # shape (n, n)

        # calculate loss
        loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (n ,n)

        # reduce loss to a scalar
        loss = jnp.mean(loss)

        # return negative loss as fitness
        # TensorNEAT maximizes fitness, equivalent to minimizing loss
        return -loss

I fear that, because pairwise_predictions and pairwise_labels keep the shape (n, n) after using jnp.where, the loss will differ from the one I would get with a boolean mask as in NumPy. I also get another error later in the pipeline: ValueError: max() iterable argument is empty, raised from line 143 of TensorNEAT's pipeline.py. Curiously, this goes away if I change pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold to pairs_to_keep = jnp.abs(pairwise_labels - pairwise_predictions) > self.threshold, but that probably also produces an incorrect loss.

Below is some code that should be enough to set up a minimal running example similar to my setup:
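To make the concern about differing losses concrete, here is a small sketch on hand-picked 2x2 arrays (the values and the name masked_bce_mean are illustrative, not taken from the setup above). It compares three quantities: the loss from eager boolean masking, a jit-compatible masked mean that keeps the (n, n) shape, and the plain mean over -inf placeholders. For these inputs the placeholder version evaluates 0 * log(0) and becomes nan.

```python
import jax
import jax.numpy as jnp

def binary_cross_entropy(prediction, target):
    return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))

logits = jnp.array([[0.0, 1.5], [-1.5, 0.0]])    # pairwise prediction differences
target = jnp.array([[0.0, 1.0], [0.0, 0.0]])     # 1.0 where the label difference is > 0
keep = jnp.array([[False, True], [True, False]])  # pairs above the threshold

# Reference: boolean masking, which only works outside jit.
per_pair = binary_cross_entropy(jax.nn.sigmoid(logits), target)
reference = jnp.mean(per_pair[keep])

@jax.jit
def masked_bce_mean(logits, target, keep):
    loss = binary_cross_entropy(jax.nn.sigmoid(logits), target)  # stays (n, n)
    kept_sum = jnp.sum(jnp.where(keep, loss, 0.0))               # ignore dropped pairs
    return kept_sum / jnp.sum(keep)                              # divide by number kept

masked = masked_bce_mean(logits, target, keep)

# Placeholder approach: sigmoid(-inf) is 0, and 0 * log(0) is nan.
placeholder_pred = jax.nn.sigmoid(jnp.where(keep, logits, -jnp.inf))
placeholder_loss = jnp.mean(binary_cross_entropy(placeholder_pred, target))

print(reference, masked, placeholder_loss)
```

The masked mean matches the boolean-mask result because the dropped pairs contribute nothing to either the numerator or the denominator.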

from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode
from tensorneat.genome.gene.conn import DefaultConn
from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem

def binary_cross_entropy(prediction, target):
    return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))

# Define the custom Problem
class CustomProblem(BaseProblem):

    jitable = True  # necessary

    def __init__(self, inputs, labels, threshold):
        self.inputs = jnp.array(inputs) #nb! already has shape (n, 768)
        self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1) 
        self.threshold = threshold

    def evaluate(self, state, randkey, act_func, params):
        # do batch forward for all inputs (using jax.vmap).
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, self.inputs
        )  # should be shape (len(labels), 1)

        #calculating pairwise labels and predictions
        pairwise_labels = self.labels - self.labels.T # shape (len(labels), len(labels))
        pairwise_predictions = predict - predict.T  # shape (len(inputs), len(inputs))

        #finding which pairs to keep
        pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold #this is the thing I actually want
        #pairs_to_keep = jnp.abs(pairwise_labels - pairwise_predictions) > self.threshold #weird fix to circumvent ValueError: max() iterable argument is empty when using jnp.where for pairwise_labels and pairwise_predictions
        print(pairs_to_keep.shape)

        pairwise_labels = pairwise_labels[pairs_to_keep] #normal boolean mask that doesnt work
        #pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf) #using jnp.where to circumvent NonConcreteBooleanIndexError, but gives different shape than I want
        pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
        print(pairwise_labels.shape)

        pairwise_predictions = pairwise_predictions[pairs_to_keep] #normal boolean mask that doesnt work
        #pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf) #using jnp.where to circumvent NonConcreteBooleanIndexError, but gives different shape than I want
        pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
        print(pairwise_predictions.shape)

        # calculate loss
        loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (len(labels), len(labels))

        # reduce loss to a scalar
        loss = jnp.mean(loss)

        # return negative loss as fitness
        # TensorNEAT maximizes fitness, equivalent to minimizing loss
        return -loss

    @property
    def input_shape(self):
        # the input shape that the act_func expects
        return (self.inputs.shape[1],)

    @property
    def output_shape(self):
        # the output shape that the act_func returns
        return (1,)

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        # showcase the performance of one individual
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)

        loss = jnp.mean(jnp.square(predict - self.labels))

        n_elements = 5
        if n_elements > len(self.inputs):
            n_elements = len(self.inputs)

        msg = f"Looking at {n_elements} first elements of input\n"
        for i in range(n_elements):
            msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
        msg += f"total loss: {loss}\n"
        print(msg)

algorithm = algorithm.NEAT(
    pop_size=10,
    survival_threshold=0.2,
    min_species_size=2,
    compatibility_threshold=3.0,  
    species_elitism=2,  
    genome=genome.DefaultGenome(
        num_inputs=768,
        num_outputs=1,
        max_nodes=769,  # must at least be same as inputs and outputs
        max_conns=768,  # must be 768 connections for the network to be fully connected
        output_transform=common.ACT.sigmoid,
        mutation=mutation.DefaultMutation(
            # do not allow adding or deleting nodes
            node_add=0.0,
            node_delete=0.0,
            # set mutation rates for edges to 0.5
            conn_add=0.5,
            conn_delete=0.5,
        ),
        node_gene=DefaultNode(),
        conn_gene=DefaultConn(),
    ),
)


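key_inputs, key_labels = jax.random.split(jax.random.PRNGKey(0)) #use separate keys so inputs and labels are independent
INPUTS = jax.random.uniform(key_inputs, (100, 768)) #the input data x
LABELS = jax.random.uniform(key_labels, (100,)) #the annotated labels y; the shape must be a tuple, (100) is just the int 100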

problem = CustomProblem(INPUTS, LABELS, 0.25)

print("Setting up pipeline and running it")
print("-----------------------------------------------------------------------")
pipeline = Pipeline(
    algorithm,
    problem,
    generation_limit=1,
    fitness_target=1,
    seed=42,
)

state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)


1 Answer

Yes, the mask operation makes the shape of the resulting array depend on the contents of the array, and JAX only supports static shapes in traced code such as functions under jit or vmap. The workaround you propose, using -inf as a placeholder, looks reasonable. The missing part is ignoring the masked-out entries in the mean. You could achieve this with a custom "masked" mean function along the lines of:

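from jax import numpy as jnp
from jax import random
import jax

key = random.PRNGKey(0)

key, subkey = random.split(key)
x = random.normal(subkey, (4, 4))

key, subkey = random.split(key)
mask = random.bernoulli(subkey, 0.5, (4, 4))  # use the fresh subkey, not the parent key

@jax.jit
def masked_mean(x, mask):
    # Zero out the masked-off entries, then divide by the number of kept entries.
    # A column with no kept entries yields 0/0 = nan.
    return jnp.sum(jnp.where(mask, x, 0.0), axis=0) / jnp.sum(mask, axis=0)


masked_mean(x, mask)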

I have not checked the other parts of the code in detail, but, for example, the statement jnp.where(pairwise_labels > 0, True, False) is redundant: it is equivalent to pairwise_labels > 0. And with the masked mean you might not need the placeholder values at all.
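Putting these pieces together, the loss from the question could be written without placeholders roughly as follows. This is a sketch, not a tested TensorNEAT integration. The function name pairwise_masked_loss is hypothetical, and predict and labels are assumed to have shape (n, 1) as in the question. The jnp.maximum guard against zero kept pairs is an added assumption, not something from the original code.

```python
import jax
import jax.numpy as jnp

def binary_cross_entropy(prediction, target):
    return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))

def pairwise_masked_loss(predict, labels, threshold):
    """Mean BCE over pairs whose label difference exceeds threshold."""
    pairwise_labels = labels - labels.T          # shape (n, n)
    pairwise_predictions = predict - predict.T   # shape (n, n)
    keep = jnp.abs(pairwise_labels) > threshold  # static shape (n, n)

    target = (pairwise_labels > 0).astype(pairwise_predictions.dtype)
    per_pair = binary_cross_entropy(jax.nn.sigmoid(pairwise_predictions), target)

    # Masked mean: dropped pairs add nothing to the sum or to the count.
    total = jnp.sum(jnp.where(keep, per_pair, 0.0))
    n_kept = jnp.maximum(jnp.sum(keep), 1)       # avoid 0/0 when no pair is kept
    return total / n_kept

labels = jnp.array([[0.1], [0.5], [0.9]])
predict = jnp.array([[0.2], [0.4], [0.7]])
loss = jax.jit(pairwise_masked_loss)(predict, labels, 0.25)
print(loss)
```

Inside evaluate, the return value would then be the negative of this loss, since the question states that TensorNEAT maximizes fitness.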

I hope this helps!
