I'm trying to create a CustomProblem that inherits from the BaseProblem class in TensorNEAT, a JAX-based library. In implementing the evaluate function of this class I use a boolean mask, but I can't get it to work. My code raises jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[n,n]), which I think is because some of my arrays do not have a definite shape. How do I work around this?
Consider this example in NumPy:
import numpy as np

ran_int = np.random.randint(1, 5, size=(2, 2))
print(ran_int)
ran_bool = np.random.randint(0, 2, size=(2, 2), dtype=bool)
print(ran_bool)
a = (ran_int[ran_bool] > 0).astype(int)
print(a)
It could give an output like this:
[[2 2]
 [3 4]]
[[ True False]
 [ True  True]]
[1 1 1] # 1-D, and has fewer elements than before the boolean mask was applied!
But in JAX, the same way of thinking raises the NonConcreteBooleanIndexError I quoted above.
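For reference, here is a minimal, standalone repro of what I believe is the same failure (the helper name is just for illustration): with concrete arrays the boolean indexing works, but under jit it raises the error.

import jax, jax.numpy as jnp

@jax.jit
def take_masked(x, mask):
    # under jit, mask is a traced (non-concrete) array, so the result of
    # boolean indexing would have a shape that depends on runtime values,
    # which JAX cannot represent
    return x[mask]  # raises jax.errors.NonConcreteBooleanIndexError

take_masked(jnp.arange(4.0).reshape(2, 2),
            jnp.array([[True, False], [True, True]]))

My evaluate function looks like this: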
#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
    # do batch forward for all inputs (using jax.vmap)
    predict = jax.vmap(act_func, in_axes=(None, None, 0))(
        state, params, self.inputs
    )  # should be shape (n, 1)

    # calculate pairwise labels and predictions
    pairwise_labels = self.labels - self.labels.T  # shape (n, n)
    pairwise_predictions = predict - predict.T  # shape (n, n)

    # find which pairs to keep
    pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold
    print(pairs_to_keep.shape)  # this prints (n, n)

    pairwise_labels = pairwise_labels[pairs_to_keep]  # ERROR HAPPENS HERE
    pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
    print(pairwise_labels.shape)  # want a 1-D shape with potentially fewer than n*n elements, depending on the boolean mask

    pairwise_predictions = pairwise_predictions[pairs_to_keep]  # WOULD HAPPEN HERE TOO IF THIS PART CAME FIRST
    pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
    print(pairwise_predictions.shape)  # want a 1-D shape with potentially fewer than n*n elements, depending on the boolean mask

    # calculate loss
    loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # 1-D
    # reduce loss to a scalar
    loss = jnp.mean(loss)
    # return negative loss as fitness
    # TensorNEAT maximizes fitness, equivalent to minimizing loss
    return -loss
I considered using jnp.where to solve the issue, but then the resulting pairwise_labels and pairwise_predictions keep the shape (n, n) rather than the reduced 1-D shape I expect, as seen in the code below:
#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
    # do batch forward for all inputs (using jax.vmap)
    predict = jax.vmap(act_func, in_axes=(None, None, 0))(
        state, params, self.inputs
    )  # should be shape (n, 1)

    # calculate pairwise labels and predictions
    pairwise_labels = self.labels - self.labels.T  # shape (n, n)
    pairwise_predictions = predict - predict.T  # shape (n, n)

    # find which pairs to keep
    pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold
    print(pairs_to_keep.shape)  # this prints (n, n)

    # one problem: now I have -inf instead of discarding the element entirely
    pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf)
    pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
    print(pairwise_labels.shape)  # shape (n, n)

    # same problem here: -inf instead of discarding the element entirely
    pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf)
    pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
    print(pairwise_predictions.shape)  # shape (n, n)

    # calculate loss
    loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (n, n)
    # reduce loss to a scalar
    loss = jnp.mean(loss)
    # return negative loss as fitness
    # TensorNEAT maximizes fitness, equivalent to minimizing loss
    return -loss
I fear that the (n, n) shapes of pairwise_predictions and pairwise_labels after using jnp.where will give a different loss than if I had simply applied the boolean mask as in NumPy. There is also another error later in the pipeline: ValueError: max() iterable argument is empty, raised from line 143 of TensorNEAT's pipeline.py. Curiously, this is circumvented by changing pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold to pairs_to_keep = jnp.abs(pairwise_labels - pairwise_predictions) > self.threshold, which probably also yields an incorrect loss.
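My suspicion (unverified) is that the -inf placeholders poison the loss: sigmoid(-inf) is 0, and binary cross-entropy at a prediction of exactly 0 multiplies 0 by log(0), giving NaN, which then propagates through jnp.mean into the fitness. A quick check of that failure mode:

import jax, jax.numpy as jnp

p = jax.nn.sigmoid(-jnp.inf)                # 0.0
t = 0.0                                     # a masked-out pair: -inf > 0 is False
bce = -(t * jnp.log(p) + (1 - t) * jnp.log(1 - p))
print(bce)                                  # nan, from 0 * log(0)
print(jnp.mean(jnp.array([1.0, bce])))      # nan propagates through the mean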
Below is some code that should be enough to set up a minimal running example similar to my setup:
from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode
from tensorneat.genome.gene.conn import DefaultConn
from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem

def binary_cross_entropy(prediction, target):
    return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
# Define the custom Problem
class CustomProblem(BaseProblem):
    jitable = True  # necessary

    def __init__(self, inputs, labels, threshold):
        self.inputs = jnp.array(inputs)  # nb! already has shape (n, 768)
        self.labels = jnp.array(labels).reshape((-1, 1))  # nb! has shape (n,), must be reshaped to (n, 1)
        self.threshold = threshold

    def evaluate(self, state, randkey, act_func, params):
        # do batch forward for all inputs (using jax.vmap)
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, self.inputs
        )  # should be shape (len(labels), 1)

        # calculate pairwise labels and predictions
        pairwise_labels = self.labels - self.labels.T  # shape (len(labels), len(labels))
        pairwise_predictions = predict - predict.T  # shape (len(inputs), len(inputs))

        # find which pairs to keep
        pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold  # this is what I actually want
        #pairs_to_keep = jnp.abs(pairwise_labels - pairwise_predictions) > self.threshold  # weird fix to circumvent ValueError: max() iterable argument is empty when using jnp.where for pairwise_labels and pairwise_predictions
        print(pairs_to_keep.shape)

        pairwise_labels = pairwise_labels[pairs_to_keep]  # normal boolean mask that doesn't work
        #pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf)  # using jnp.where circumvents NonConcreteBooleanIndexError, but gives a different shape than I want
        pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
        print(pairwise_labels.shape)

        pairwise_predictions = pairwise_predictions[pairs_to_keep]  # normal boolean mask that doesn't work
        #pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf)  # using jnp.where circumvents NonConcreteBooleanIndexError, but gives a different shape than I want
        pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
        print(pairwise_predictions.shape)

        # calculate loss
        loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (len(labels), len(labels))
        # reduce loss to a scalar
        loss = jnp.mean(loss)
        # return negative loss as fitness
        # TensorNEAT maximizes fitness, equivalent to minimizing loss
        return -loss

    @property
    def input_shape(self):
        # the input shape that the act_func expects
        return (self.inputs.shape[1],)

    @property
    def output_shape(self):
        # the output shape that the act_func returns
        return (1,)

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        # showcase the performance of one individual
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)
        loss = jnp.mean(jnp.square(predict - self.labels))
        n_elements = min(5, len(self.inputs))
        msg = f"Looking at {n_elements} first elements of input\n"
        for i in range(n_elements):
            msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
        msg += f"total loss: {loss}\n"
        print(msg)
algorithm = algorithm.NEAT(
    pop_size=10,
    survival_threshold=0.2,
    min_species_size=2,
    compatibility_threshold=3.0,
    species_elitism=2,
    genome=genome.DefaultGenome(
        num_inputs=768,
        num_outputs=1,
        max_nodes=769,  # must be at least as many as inputs plus outputs
        max_conns=768,  # 768 connections needed for a fully connected network
        output_transform=common.ACT.sigmoid,
        mutation=mutation.DefaultMutation(
            # do not allow adding or deleting nodes
            node_add=0.0,
            node_delete=0.0,
            # set mutation rates for edges to 0.5
            conn_add=0.5,
            conn_delete=0.5,
        ),
        node_gene=DefaultNode(),
        conn_gene=DefaultConn(),
    ),
)
INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768))  # the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100,))  # the annotated labels y
problem = CustomProblem(INPUTS, LABELS, 0.25)

print("Setting up pipeline and running it")
print("-----------------------------------------------------------------------")
pipeline = Pipeline(
    algorithm,
    problem,
    generation_limit=1,
    fitness_target=1,
    seed=42,
)
state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)
1 Answer
Yes, the mask operation makes the shape of the resulting array depend on the content of the array, and JAX only supports static shapes. The workaround you propose, using -inf as a placeholder value, looks reasonable. The missing part is ignoring the placeholder entries when taking the mean, which you could achieve with a custom "masked" mean function along the lines of:
from jax import numpy as jnp
from jax import random
import jax

key = random.PRNGKey(0)
x = random.normal(key, (4, 4))
key, subkey = random.split(key)
mask = random.bernoulli(subkey, 0.5, (4, 4))

@jax.jit
def masked_mean(x, mask):
    # zero out the dropped entries, then divide by the count of kept entries
    return jnp.sum(jnp.where(mask, x, 0.0), axis=0) / jnp.sum(mask, axis=0)

masked_mean(x, mask)
I have not checked the other parts of the code in detail, but note, for example, that the statement jnp.where(pairwise_labels > 0, True, False) has no effect beyond what pairwise_labels > 0 already computes. And with the masked mean you might not need the placeholder values at all.
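To make that concrete, here is a minimal sketch of how your evaluate could use a masked mean directly, assuming the names from your snippet (binary_cross_entropy, self.labels, self.threshold, and so on) and skipping the -inf placeholders entirely; the eps clipping and the empty-mask guard are my additions:

def evaluate(self, state, randkey, act_func, params):
    predict = jax.vmap(act_func, in_axes=(None, None, 0))(
        state, params, self.inputs
    )  # (n, 1)
    pairwise_labels = self.labels - self.labels.T             # (n, n)
    pairwise_predictions = predict - predict.T                # (n, n)
    pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold # (n, n) bool

    targets = (pairwise_labels > 0).astype(jnp.float32)       # no jnp.where needed
    probs = jax.nn.sigmoid(pairwise_predictions)
    eps = 1e-7
    probs = jnp.clip(probs, eps, 1 - eps)                     # keep log() finite

    loss = binary_cross_entropy(probs, targets)               # (n, n), all pairs
    # masked mean: only the kept pairs contribute to the scalar loss
    n_kept = jnp.maximum(jnp.sum(pairs_to_keep), 1)           # guard against an empty mask
    loss = jnp.sum(jnp.where(pairs_to_keep, loss, 0.0)) / n_kept
    return -loss

Because every array keeps its static (n, n) shape and the discarded pairs are excluded only at the reduction step, this should stay jittable and give the same mean as the NumPy-style boolean mask would.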
I hope this helps!