python - RMSNorm derivative using sympy -- problem with summation over fixed number of elements - Stack Overflow

I have the following SymPy expression for RMSNorm (easier to read in a Jupyter notebook):

import sympy as sp

# Define the symbols
x = sp.Symbol('x')  # Input variable
n = sp.Symbol('n')  # Number of elements
gamma = sp.Symbol('gamma')
epsilon = sp.Symbol('epsilon')  # Small constant to avoid division by zero

# Define the RMS normalization equation
mean_square = sp.Sum(x**2, (x, 1, n)) / n
rms = sp.sqrt(mean_square + epsilon)
fwd_out = x * gamma / rms

# Display the equation
sp.pprint(fwd_out)

I have an issue with the rms term when I take the derivative of fwd_out with respect to x:

d_activation = sp.diff(fwd_out, x)

SymPy does not treat rms as a function of x. It treats rms as a constant, because the sum is evaluated over 1..n, so the following displays 0:

sp.diff(rms, x)

But according to the RMSNorm paper, rms should be considered a function of x.

Is there a way to force SymPy to treat rms as a function of x?

I am using Python 3.12.9 and SymPy 1.12.1.
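
The zero derivative described above can be reproduced by checking which symbols rms actually depends on. This is a minimal sketch that reuses the names from the question: since x is the bound variable of the Sum, it is not among the free symbols of rms, and differentiating by x only sees the x that multiplies gamma.

```python
import sympy as sp

x, n, gamma, epsilon = sp.symbols('x n gamma epsilon')

# x is the bound (dummy) variable of the Sum, so it is not a free symbol of it
mean_square = sp.Sum(x**2, (x, 1, n)) / n
rms = sp.sqrt(mean_square + epsilon)
fwd_out = x * gamma / rms

rms_free = rms.free_symbols    # x does not appear in this set
d_rms = sp.diff(rms, x)        # 0, because rms has no free x
d_fwd = sp.diff(fwd_out, x)    # only the leading x is differentiated

print(rms_free)
print(d_rms)
print(d_fwd)
```

So the result is not a bug in diff. The expression as written simply does not contain the input vector as a free quantity.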


Complete answer, based on @smichr's answer:

from sympy import *
from sympy.abc import n, gamma, epsilon

# Indexed input vector x[1], ..., x[n]
x = IndexedBase("x")
i = symbols('i', cls=Idx)

# RMS normalization of element x[i]
mean_squared = Sum(x[i] ** 2, (i, 1, n)) / n
rms = sqrt(mean_squared + epsilon)
fwd_out = x[i] * gamma / rms

# diff wrt x[i]
d_fwd_out = diff(fwd_out, x[i])

d_rms = diff(rms, x[i])
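
When the number of elements is fixed, one way to sanity-check the symbolic result is to write the elements out explicitly. The sketch below assumes n = 3 and uses hypothetical symbols x1, x2, x3 (neither is from the original post). It compares SymPy's Jacobian with the hand-derived form gamma * (delta_kj / rms - x_k * x_j / (n * rms**3)).

```python
import sympy as sp

n = 3  # fixed number of elements (an assumption for this sketch)
xs = sp.symbols(f'x1:{n + 1}', real=True)  # hypothetical names x1, x2, x3
gamma = sp.Symbol('gamma', real=True)
epsilon = sp.Symbol('epsilon', positive=True)

mean_square = sum(xi**2 for xi in xs) / n
rms = sp.sqrt(mean_square + epsilon)
fwd_out = [xi * gamma / rms for xi in xs]

# Jacobian entry [k][j] is d fwd_out[k] / d xs[j]
jacobian = sp.Matrix([[sp.diff(fk, xj) for xj in xs] for fk in fwd_out])

# Hand-derived closed form, written out entry by entry
expected = sp.Matrix([
    [gamma * ((1 if k == j else 0) / rms - xs[k] * xs[j] / (n * rms**3))
     for j in range(n)]
    for k in range(n)
])

difference = (jacobian - expected).applyfunc(sp.simplify)
print(difference)
```

The off-diagonal entries are nonzero because every output element depends on every input element through rms.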

Ref:

RMSNorm Paper: https://arxiv./pdf/1910.07467
Pytorch API: https://pytorch./docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html

Asked Mar 20 at 16:09 by algoProg; edited Mar 21 at 4:52 by jared.
1 Answer

In the paper the sum runs over indexed elements, i.e. Sum(a[i]**2, (i, 1, n)). If you create an indexed variable, you can differentiate with respect to it:

from sympy import *
from sympy.abc import n, gamma, epsilon

a = IndexedBase('a')
i = symbols('i', cls=Idx)  # index for a[i]; assumed to be an Idx
# Define the RMS normalization equation
mean_square = Sum(a[i]**2, (i, 1, n)) / n
rms = sqrt(mean_square + epsilon)
fwd_out = a[i] * gamma / rms

>>> print(str(rms.diff(a[i])))
Sum(2*a[i], (i, 1, n))/(2*n*sqrt(epsilon + Sum(a[i]**2, (i, 1, n))/n))
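
In the output above, the element being differentiated shares its index with the summation, so the numerator still reads as a sum over all elements. A possible refinement is to differentiate by a[j] with a separate index j; SymPy then keeps a KroneckerDelta(i, j) inside the sum, which picks out the single term a[j]. This is a sketch only: the index j and the spot-check values are assumptions, not part of the original answer.

```python
import sympy as sp

n = sp.Symbol('n', integer=True, positive=True)
epsilon = sp.Symbol('epsilon', positive=True)
a = sp.IndexedBase('a')
# i is the summation index, j is the element we differentiate by
i, j = sp.symbols('i j', integer=True)

mean_square = sp.Sum(a[i]**2, (i, 1, n)) / n
rms = sp.sqrt(mean_square + epsilon)

# Differentiating by a[j] leaves KroneckerDelta(i, j) inside the sum
d_rms = sp.diff(rms, a[j])
print(d_rms)

# Spot check for n = 3, j = 2 against the hand-derived a[j] / (n * rms)
concrete = d_rms.subs({n: 3, j: 2}).doit()
expected = (a[2] / (n * rms)).subs(n, 3).doit()
values = {a[1]: 1, a[2]: 2, a[3]: 3, epsilon: sp.Rational(1, 3)}
check = sp.simplify(concrete.subs(values) - expected.subs(values))
print(check)
```

Substituting concrete values for n and j before calling doit() avoids the Piecewise conditions that a fully symbolic evaluation of the delta sum can introduce.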
