I have a 3D PyTorch tensor of shape 32 x 15000 x (128 * batch_size) (huge, I know; I cannot shrink it), and I'm currently using broadcasting to generate a histogram for each 32 x 15000 slice. I store the lower and upper edges of each bin (lowers and uppers) so that binning is a simple broadcast comparison.
Is there a torch.histogram alternative that would let me do this? Broadcasting consumes a huge amount of memory because the mask ends up being 32 x 15000 x (128 * batch_size * bins), and I cannot fit that in VRAM.
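To make the memory problem concrete, here is a minimal sketch of what the broadcast approach might look like, using small stand-in sizes. The function names, the bin edges shared across all slices with shape (bins,), and the bin count of 100 in the size estimate are assumptions for illustration; the real code may differ.

```python
import torch

def hist_broadcast(x, lowers, uppers):
    """Histogram each (H, W) slice of x via one broadcast comparison.

    x: (H, W, N); lowers, uppers: (bins,) (assumed shared by all slices).
    """
    v = x.unsqueeze(-1)                    # (H, W, N, 1)
    # The mask has shape (H, W, N, bins); this is the tensor that exhausts VRAM.
    mask = (v >= lowers) & (v < uppers)
    return mask.sum(dim=(0, 1))            # (N, bins) integer counts

def mask_bytes(h, w, n, bins):
    """Size of the boolean mask in bytes (torch.bool uses 1 byte per element)."""
    return h * w * n * bins

# Small stand-in sizes so the sketch runs anywhere.
x = torch.rand(4, 50, 6)
edges = torch.linspace(0.0, 1.0, 11)       # 10 bins covering [0, 1)
counts = hist_broadcast(x, edges[:-1], edges[1:])

# Full-size estimate with batch_size = 1 and a hypothetical 100 bins.
full_size = mask_bytes(32, 15000, 128, 100)
```

Under those assumptions the mask alone is about 6.1 GB per unit of batch_size, before counting any temporaries created by the two comparisons.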
Any suggestions would be welcome. I'm currently working around the issue by looping over the bins and handling them one at a time, but that is far slower.
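The per-bin workaround could look roughly like the sketch below. It is an illustration under assumptions (the function name `hist_per_bin`, bin edges of shape (bins,) shared by all slices, small stand-in sizes), not necessarily the exact loop in use.

```python
import torch

def hist_per_bin(x, lowers, uppers):
    """Histogram each (H, W) slice of x, processing one bin per iteration.

    x: (H, W, N); lowers, uppers: (bins,) (assumed shared by all slices).
    """
    n = x.shape[-1]
    bins = lowers.numel()
    counts = torch.zeros(n, bins, dtype=torch.long, device=x.device)
    for b in range(bins):
        # Only one (H, W, N) mask is alive at a time, so peak memory is
        # roughly 1/bins of the fully broadcast mask.
        in_bin = (x >= lowers[b]) & (x < uppers[b])
        counts[:, b] = in_bin.sum(dim=(0, 1))
    return counts

# Small stand-in sizes so the sketch runs anywhere.
x = torch.rand(4, 50, 6)
edges = torch.linspace(0.0, 1.0, 11)       # 10 bins covering [0, 1)
counts = hist_per_bin(x, edges[:-1], edges[1:])
```

This keeps memory bounded, but every bin triggers its own full pass over the tensor, which is where the slowdown comes from.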