I am new to the Mamba model. I have read a few papers reporting that it performs robustly on image segmentation tasks. If anyone has implemented it before, is there any guidance on the correct setup for the Mamba block, or on the input image patches and dimensions, that could lead to the best results?
So far, adding the Mamba block to my code has not shown any advantage. Here is a small snippet of my implementation:
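import torch
from einops import rearrange
from mamba_ssm import Mamba   # assumption: Mamba from the mamba_ssm package (its kernels expect CUDA tensors)
from torch.nn import RMSNorm  # assumption: RMSNorm from PyTorch >= 2.4

x = torch.rand(1, 16, 256, 256)  # (batch, channels, height, width)
norm = RMSNorm(16**2)
mamba = Mamba(16**2)
_, c, h, _ = x.shape
# Split every channel into 16x16 patches: one token per (channel, patch) pair,
# with the 256 pixels of that patch as features -> (1, 4096, 256)
x = rearrange(x, 'b c (p1 ph) (p2 pw) -> b (c p1 p2) (ph pw)', ph=16, pw=16)
x = mamba(norm(x)) + x  # pre-norm Mamba block with a residual connection
# Undo the patching to recover the original (1, 16, 256, 256) layout
x = rearrange(x, 'b (c p) d -> b c p d', c=c)
x = rearrange(x, 'b c (p1 p2) (ph pw) -> b c (p1 ph) (p2 pw)', p1=h//16, ph=16)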
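To make the shapes in the snippet explicit, here is a minimal sketch of the arithmetic behind the first rearrange. The helper name token_shape is hypothetical and used only for illustration; it makes no Mamba call and simply mirrors the einops pattern, assuming square patches.

```python
def token_shape(batch, channels, height, width, patch=16):
    """Shape produced by 'b c (p1 ph) (p2 pw) -> b (c p1 p2) (ph pw)'.

    Hypothetical helper for illustration; assumes square patches (ph == pw).
    """
    if height % patch or width % patch:
        raise ValueError("height and width must be divisible by the patch size")
    p1, p2 = height // patch, width // patch  # patches per column, patches per row
    seq_len = channels * p1 * p2              # one token per (channel, patch) pair
    d_model = patch * patch                   # features: pixels of one single-channel patch
    return (batch, seq_len, d_model)


print(token_shape(1, 16, 256, 256))  # (1, 4096, 256)
```

With these numbers the Mamba block scans 4096 tokens of width 256, so the 16**2 passed to RMSNorm and Mamba corresponds to the patch area, not to the channel count.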