In my framework I have an outer loop (mocked here by the variable n), and inside the loop body I have to perform matrix inversions/multiplications over multiple batch dimensions. I observed that manually looping over the batch dimensions and inverting each matrix individually is measurably faster than passing all batches at once to the torch.linalg.inv function.
Similar observations hold for matrix multiplications computed with the torch.einsum function.
My expectation was that passing all batches at once would perform better than a Python for-loop. Any ideas/explanations/recommendations?
Profiling output of the function inverse_batch:
ncalls tottime percall cumtime percall filename:lineno(function)
100 2.112 0.021 2.112 0.021 {built-in method torch._C._linalg.linalg_inv}
1 0.000 0.000 2.112 2.112 mwe.py:5(inverse_batch)
1 0.000 0.000 2.112 2.112 {built-in method builtins.exec}
1 0.000 0.000 2.112 2.112 <string>:1(<module>)
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
Profiling output of the function inverse_loop:
ncalls tottime percall cumtime percall filename:lineno(function)
8000 0.207 0.000 0.207 0.000 {built-in method torch._C._linalg.linalg_inv}
1 0.022 0.022 0.229 0.229 mwe.py:9(inverse_loop)
1 0.000 0.000 0.000 0.000 {method 'view' of 'torch._C.TensorBase' objects}
1 0.000 0.000 0.229 0.229 {built-in method builtins.exec}
1 0.000 0.000 0.229 0.229 <string>:1(<module>)
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
Code:
import torch
import cProfile
import pstats
def inverse_batch(tensors, n):
    # Invert all 10 * 8 matrices in one batched call, repeated n times
    for i in range(n):
        torch.linalg.inv(tensors)

def inverse_loop(tensors, n):
    # Flatten the batch dimensions and invert each 3x3 matrix individually
    tensors = tensors.view(-1, 3, 3)
    for i in range(n):
        for j in range(10 * 8):
            torch.linalg.inv(tensors[j])

# Create a batch of tensors
tensors = torch.randn(10, 8, 3, 3, dtype=torch.double)  # Shape: (10, 8, 3, 3)

# Profile both variants
n = 100  # Dummy outer loop variable
cProfile.run('inverse_batch(tensors, n)', 'profile_output')
stats = pstats.Stats('profile_output')
stats.strip_dirs().sort_stats('tottime').print_stats()

cProfile.run('inverse_loop(tensors, n)', 'profile_output')
stats = pstats.Stats('profile_output')
stats.strip_dirs().sort_stats('tottime').print_stats()
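(Side note: cProfile adds overhead to every Python-level call, so a wall-clock cross-check is useful. A minimal sketch using torch.utils.benchmark, which handles warm-up itself; this Timer code is an addition for illustration, not part of my actual framework:)

import torch
import torch.utils.benchmark as benchmark

tensors = torch.randn(10, 8, 3, 3, dtype=torch.double)
flat = tensors.view(-1, 3, 3)

# One batched call over all 80 matrices
t_batch = benchmark.Timer(
    stmt='torch.linalg.inv(tensors)',
    globals={'tensors': tensors, 'torch': torch},
)
# 80 individual 3x3 inversions
t_loop = benchmark.Timer(
    stmt='[torch.linalg.inv(flat[j]) for j in range(flat.shape[0])]',
    globals={'flat': flat, 'torch': torch},
)
print(t_batch.timeit(100))
print(t_loop.timeit(100))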
2 Answers
So I've been poking around, and I think what is going on is this: PyTorch simplifies the problem of inverting multiple matrices by reshaping the tensor into a single square 2D matrix, inverting that whole matrix, and then reshaping back. From the source code comments:
/*
The idea is to reduce the problem to 2D square matrix inversion.
Step 1. Calculate the shape of the result and the shape of the intermediate 2D matrix.
Step 2. Reshape `self` to 2D matrix.
Step 3. Invert the 2D matrix self.to_2D()
There is no quick way to find out whether the matrix is invertible,
so at this stage an error from at::inverse can be thrown.
Note that for CUDA this causes cross-device memory synchronization that can be slow.
Step 4. reshape the result.
*/
This sounds like it should be faster, but the time complexity of matrix inversion scales with the cube of the matrix size, i.e. O(N³). By "breaking down" the tensor yourself through looping, you save PyTorch the trouble of solving one big tensor problem, encoding the same information implicitly in byte-size (pun intended) tensors.
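To put rough numbers on that (assuming the reshape really produces one big square matrix, as the source comment suggests): the example batch holds 10 · 8 = 80 matrices of size 3×3, so the combined 2D matrix would be 240×240. An O(N³) inverse on that costs on the order of 240³ ≈ 1.4·10⁷ operations, while 80 separate 3×3 inverses cost on the order of 80 · 3³ ≈ 2.2·10³, roughly four orders of magnitude less.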
Smarter people than I can swoop in and correct this if I’m off base, but that’s what I’ve managed to find for you. Hope it helps.
ETA: You could test this pretty simply by running the same reshape-invert-reshape steps on your tensor in Python and timing that. I'd suspect it will end up being very similar in time to the non-looped method.
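A minimal sketch of that test, assuming the intermediate 2D matrix is the block-diagonal one built with torch.block_diag (the actual internal layout may well differ):

import torch
import timeit

tensors = torch.randn(10, 8, 3, 3, dtype=torch.double)
flat = tensors.reshape(-1, 3, 3)  # 80 matrices of shape (3, 3)

def inv_block_diag():
    # Reshape to one big 2D matrix, invert, then recover the blocks
    big = torch.block_diag(*flat)  # shape (240, 240)
    inv = torch.linalg.inv(big)
    return torch.stack([inv[3 * j:3 * j + 3, 3 * j:3 * j + 3]
                        for j in range(flat.shape[0])])

def inv_looped():
    # Invert each 3x3 matrix individually, as in the question's loop
    return torch.stack([torch.linalg.inv(flat[j]) for j in range(flat.shape[0])])

print(timeit.timeit(inv_block_diag, number=100))
print(timeit.timeit(inv_looped, number=100))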
Use vmap to map a vector of inputs to the inverse function:
batch_inv = torch.vmap(torch.linalg.inv)
Then view your matrices as a single stack of shape [n, ., .] and pass it to batch_inv:
original_shape = tensors.shape
tensors = tensors.view(-1, 3, 3)      # flatten the batch dims: (80, 3, 3)
output = batch_inv(tensors)           # vmapped inverse over the leading dim
output = output.view(original_shape)  # restore the original batch shape
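A quick sanity check (an addition for illustration, not from the answer above) that the vmapped result agrees with the natively batched torch.linalg.inv:

import torch

batch_inv = torch.vmap(torch.linalg.inv)
tensors = torch.randn(10, 8, 3, 3, dtype=torch.double)
flat = tensors.view(-1, 3, 3)

# vmap over the leading dimension should match the batched call
assert torch.allclose(batch_inv(flat), torch.linalg.inv(flat))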
Comments:
Ran inverse_batch(tensors, 100) and inverse_loop(tensors, 100) 100 times each on GPU. inverse_batch takes an average of 7.3 ms while inverse_loop takes an average of 691 ms. – Karl, Feb 6 at 23:04
… (inverse_batch slower than inverse_loop). Or does pytorch need a lot of free CPU resources to make the batch-wise computations effective? – Mathieu, Feb 7 at 7:22
… (inverse_batch is faster than inverse_loop). PyTorch uses multiple cores for CPU computation by default, so other processes may have impacted the benchmark. – Karl, Feb 7 at 17:37