I created a custom pytree node following this tutorial. My custom pytree node works for many operations and transformations, except grad.
Here is my code:
import jax
import jax.numpy as jnp
from jax import tree_util
class MyLinear:
    def __init__(self, w, b):
        self.w = jnp.array(w)
        self.b = jnp.array(b)

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

    def tree_flatten(self):
        return (self.w, self.b), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

tree_util.register_pytree_node(MyLinear, MyLinear.tree_flatten, MyLinear.tree_unflatten)
Then I tested this code:
inputs = jnp.ones((2,2))
input = inputs[0]
mylinear = MyLinear([1.0, 1.0], 1.)
print(tree_util.tree_map(lambda x: x+1, input)) # [2. 2.] works as expected
print(jax.vmap(mylinear)(inputs)) # [3. 3.] works as expected
print(jax.jit(jax.vmap(mylinear))(inputs)) # [3. 3.] works as expected
def loss_fn(model, x):
    out = jax.vmap(model)(x)
    return jnp.sum(out ** 2)
print(loss_fn(mylinear, inputs)) # 18.0 works as expected
print(jax.grad(loss_fn)(mylinear, inputs)) # <__main__.MyLinear object at 0x10e317890> doesn't work as expected
It seems grad doesn't recognize a MyLinear object as a flattenable, tree-like object the way it does a tuple, list, or dict. What should the code be so that objects of this class are recognized by all JAX transformations? Thank you for the help.
I think there is a misunderstanding here. From what I can tell, the code works exactly as intended. JAX operates on "structs of arrays", so your MyLinear class serves as a data container for the arrays as well as for their gradients. When you apply jax.grad() to a function of a pytree, JAX returns a pytree of the same structure whose leaves hold the gradients of the corresponding nodes. The unhelpful <__main__.MyLinear object at 0x10e317890> printout is just Python's default repr for a class without a __repr__ method, not a sign that grad failed. You can access the individual gradients like so:
inputs = jnp.ones((2, 2))
mylinear = MyLinear([1.0, 1.0], 1.)

def loss_fn(model, x):
    out = jax.vmap(model)(x)
    return jnp.sum(out ** 2)

grads = jax.grad(loss_fn)(mylinear, inputs)
print(grads.w)  # [12. 12.]
print(grads.b)  # 12.0
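If you want print to display the gradient values directly, a minimal, purely optional sketch (my addition, not something JAX requires) is to give MyLinear a readable repr:

# Optional and cosmetic only: show leaf values when printing a MyLinear pytree.
MyLinear.__repr__ = lambda self: f"MyLinear(w={self.w}, b={self.b})"
print(jax.grad(loss_fn)(mylinear, inputs))  # MyLinear(w=[12. 12.], b=12.0)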
I hope this clarifies the behavior!
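One more note: since grads has the same pytree structure as mylinear, the two compose with the usual tree utilities. Here is a minimal sketch of a single gradient-descent step; the learning rate lr is an arbitrary value I chose for illustration:

lr = 0.1
grads = jax.grad(loss_fn)(mylinear, inputs)
# tree_map pairs up the leaves of the two pytrees, so w and b are updated together,
# and tree_unflatten rebuilds the result as a new MyLinear instance.
updated = tree_util.tree_map(lambda p, g: p - lr * g, mylinear, grads)
print(loss_fn(updated, inputs))  # 0.72, down from 18.0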