
python - How to make a custom pytree node works with grad in JAX - Stack Overflow


I created a custom pytree node following this tutorial. My custom pytree node works for many operations and transformations, except grad.

Here is my code.

import jax
import jax.numpy as jnp
from jax import tree_util

class MyLinear:
    def __init__(self, w, b):
        self.w = jnp.array(w)
        self.b = jnp.array(b)

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

    def tree_flatten(self):
        return (self.w, self.b), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

tree_util.register_pytree_node(MyLinear, MyLinear.tree_flatten, MyLinear.tree_unflatten)

Then I tested this code:

inputs = jnp.ones((2,2))
input = inputs[0]
mylinear = MyLinear([1.0, 1.0], 1.)
print(tree_util.tree_map(lambda x: x+1, input)) # [2. 2.] works as expected
print(jax.vmap(mylinear)(inputs))   # [3. 3.] works as expected
print(jax.jit(jax.vmap(mylinear))(inputs))  # [3. 3.] works as expected
def loss_fn(model, x):
    out = jax.vmap(model)(x)
    return jnp.sum(out ** 2)
print(loss_fn(mylinear, inputs))    # 18.0 works as expected
print(jax.grad(loss_fn)(mylinear, inputs))  # <__main__.MyLinear object at 0x10e317890> doesn't work as expected

It seems grad doesn't recognize a MyLinear object as a flattenable tree the way it recognizes a tuple, list, or dict. What should the code look like so that objects of this class are recognized by all JAX transformations? Thank you for the help.

asked Mar 20 at 6:53 by Yahya

1 Answer

I think there is a misunderstanding here. From what I can tell, the code works exactly as intended. JAX operates on "structs of arrays", so your MyLinear class serves as a data container both for the parameters and for their gradients. When you apply jax.grad() to a function whose argument is a pytree, JAX returns a pytree with the same structure, where each leaf has been replaced by the gradient with respect to that leaf. So you can access the individual gradients like so:

inputs = jnp.ones((2,2))
mylinear = MyLinear([1.0, 1.0], 1.)

def loss_fn(model, x):
    out = jax.vmap(model)(x)
    return jnp.sum(out ** 2)

grads = jax.grad(loss_fn)(mylinear, inputs)
print(grads.w)
print(grads.b)
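Because grads is itself a MyLinear instance with the same tree structure as the model, you can combine the two with tree_util.tree_map to apply an update. Here is a minimal sketch of one gradient-descent step, reusing the MyLinear class from the question (the learning rate lr is an arbitrary choice for illustration):

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class MyLinear:
    def __init__(self, w, b):
        self.w = jnp.array(w)
        self.b = jnp.array(b)

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

    def tree_flatten(self):
        return (self.w, self.b), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

tree_util.register_pytree_node(MyLinear, MyLinear.tree_flatten, MyLinear.tree_unflatten)

def loss_fn(model, x):
    out = jax.vmap(model)(x)
    return jnp.sum(out ** 2)

inputs = jnp.ones((2, 2))
mylinear = MyLinear([1.0, 1.0], 1.0)

# grads is a MyLinear whose leaves are gradients of the loss
grads = jax.grad(loss_fn)(mylinear, inputs)

# tree_map pairs each parameter leaf with the matching gradient leaf
# and returns a new MyLinear with the updated parameters.
lr = 0.1
updated = tree_util.tree_map(lambda p, g: p - lr * g, mylinear, grads)
print(updated.w, updated.b)
```

This works precisely because grad preserves the pytree structure: the model and its gradients flatten to matching sets of leaves, so tree_map can zip them together.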

I hope this clarifies the behavior!
