I've seen several questions about the efficiency of in-place operations, but I'm actually more confused about the inner workings of PyTorch.
Let's take a simple example:
import torch

a = torch.randn(10)
b = torch.randn(10)
c = torch.randn(10)
x1 = a * b
x2 = x1 * c
In this case, things are easy. Backpropagation happens like this:
x2.grad <- 1
c.grad <- x2.grad * x1 = x1 = a * b
x1.grad <- x2.grad * c = c
b.grad <- x1.grad * a = c * a
a.grad <- x1.grad * b = c * b
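(For reference, these hand-derived gradients match what autograd computes, assuming a, b and c are created with requires_grad=True:)
a = torch.randn(10, requires_grad=True)
b = torch.randn(10, requires_grad=True)
c = torch.randn(10, requires_grad=True)
x1 = a * b
x2 = x1 * c
x2.backward(torch.ones_like(x2))      # seed x2.grad with ones
print(torch.allclose(c.grad, a * b))  # True
print(torch.allclose(b.grad, c * a))  # True
print(torch.allclose(a.grad, c * b))  # True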
Everything works correctly. However, in this scenario we are allocating two buffers: x1 and x2. Now, what happens when we do something like this:
x = a * b
x = x * c
It seems to me that the overall expression is the same. However, if we try to compute the gradients the same way we did before, we will run into the following problem:
x.grad <- 1
c.grad <- x.grad * x = x = a * b * c
Uh oh, we already made a mistake. Since we performed the multiplication with c in place, we lost the buffer containing a * b, which was needed to calculate the gradient of c.
I can imagine two possible solutions to this:
1. The previous code actually gets 'compiled' into something like 'x = a * b * c'. But I feel like this kind of optimization might fail when we try more complicated expressions.
2. The previous code actually uses intermediate buffers (like x1) instead.
However, in this case, what happens if we try to compute something like:
x = a * b
x *= c
x *= d
x *= e
x *= f
Does PyTorch actually create 3 temporary buffers (x1, x2 and x3)?
How is this kind of problem solved in modern frameworks?
1 Answer
You cannot backprop through an in-place operation that modifies a value needed for the gradient computation. However, there are cases where in-place operations appear to work, because they happen in a way that does not impact the gradient calculation.
First, let's distinguish between what is happening in Python vs. PyTorch. PyTorch operations return new tensor objects. Your example of using the x variable twice is not a problem, as x in this case points to two different tensors. You can check this with the data_ptr() method:
import torch

a = torch.randn(10)
b = torch.randn(10)
c = torch.randn(10)
x = a * b
print(x.data_ptr())
> 605338432 # `x` points to 605338432
x = x * c
print(x.data_ptr())
> 605338624 # `x` points to 605338624, a different object
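By contrast, an augmented assignment like the x *= c in your last snippet really is in-place: for tensors, *= dispatches to the in-place multiply (mul_), so the data pointer stays the same. A quick check with the tensors above (which don't require grad):
x = a * b
print(x.data_ptr())
x *= c               # augmented assignment -> in-place mul_
print(x.data_ptr())  # same address as before: the tensor was modified in place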
Now to the PyTorch level. When you compute values, PyTorch saves only the tensors necessary for backprop. If one of these saved values is later modified by an in-place operation, you get an error. You can, however, apply in-place operations to values that are not required for the gradient calculation.
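If you want to actually see which tensors get saved, newer PyTorch versions expose saved-tensor hooks (torch.autograd.graph.saved_tensors_hooks, available since roughly 1.10; the exact set of saved tensors can vary by op and version). A rough sketch using the tensors from the question:
def pack(t):
    # called each time autograd saves a tensor for the backward pass
    print("saved tensor of shape", tuple(t.shape))
    return t

a = torch.randn(10, requires_grad=True)
b = torch.randn(10, requires_grad=True)
c = torch.randn(10, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
    x = a * b  # saves `a` and `b` (needed for the gradients of the multiply)
    x = x * c  # saves the old `x` (i.e. a * b) and `c`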
Consider y = exp(x). In this case, dy/dx = exp(x) = y. For backprop, PyTorch stores the value of y to use when the gradient of x is computed. This means that if you modify y with an in-place operation and try to backprop, you'll get an error:
x = torch.randn(10, requires_grad=True)
y = x.exp() # computing `y=exp(x)` saves `y` for backprop
y.add_(10) # in-place addition
y.backward(torch.ones_like(y)) # backprop throws error
> RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
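If you do want the shifted values, one simple workaround (my own sketch, at the cost of an extra tensor) is to use the out-of-place op so the saved exp(x) stays intact:
x = torch.randn(10, requires_grad=True)
y = x.exp()                     # saves `y` for backprop, as above
y = y + 10                      # out-of-place add: the saved exp(x) is untouched
y.backward(torch.ones_like(y))  # backprop runs without error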
Now consider y = x ** 2. In this case, dy/dx = 2 * x. For backprop, PyTorch stores the value of x. In this case, y is not involved in the gradient computation. This means you can modify y with an in-place operation:
x = torch.randn(10, requires_grad=True)
y = x.pow(2) # computing `y = x**2` saves `x` for backprop
y.add_(10) # in-place addition
y.backward(torch.ones_like(y)) # backprop runs without error
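The flip side also holds: modifying the saved tensor itself still breaks backprop, even for pow. A sketch using a non-leaf intermediate (in-place ops on leaf tensors that require grad are rejected outright):
x = torch.randn(10, requires_grad=True)
z = x * 2                       # non-leaf intermediate
y = z.pow(2)                    # computing `y = z**2` saves `z` for backprop
z.add_(10)                      # in-place change to the saved tensor
y.backward(torch.ones_like(y))  # backprop throws the same RuntimeError as above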
In general you should avoid using in-place operations for anything you want to backprop through.